1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3
4 #include <ATen/native/Activation.h>
5
6 #include <cmath>
7
8 #include <thrust/tuple.h>
9
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18
19 namespace at::native {
20 namespace {
21
// Applies the threshold selection to every element visited by `iter`.
//
// The functor handed to gpu_kernel_with_scalars receives two scalar_t
// operands per element and produces one scalar_t result:
//   - `value`     when the first operand is <= `threshold`
//   - the second operand otherwise
// For floating-point scalar_t, a NaN first operand fails the `<=` comparison
// and therefore selects the second operand.
//
// Parameters:
//   iter      - iterator describing the operands to process.
//   threshold - cut-off compared against the first operand.
//   value     - replacement emitted when the comparison holds.
template <typename scalar_t>
void threshold_kernel_impl(
    TensorIteratorBase& iter,
    scalar_t threshold,
    scalar_t value) {
  // `threshold` and `value` are captured by copy ([=]) so the functor owns
  // its own scalars rather than referring to this function's parameters.
  const auto select_op =
      [=] GPU_LAMBDA(scalar_t input, scalar_t fallback) -> scalar_t {
    if (input <= threshold) {
      return value;
    }
    return fallback;
  };
  gpu_kernel_with_scalars(iter, select_op);
}
32
// Dtype-dispatching entry point for the threshold kernel.
//
// Selects scalar_t from iter.dtype() (all types covered by
// AT_DISPATCH_ALL_TYPES_AND2, plus Half and BFloat16), converts the
// type-erased `threshold` and `value` Scalars to that scalar_t, and forwards
// everything to threshold_kernel_impl<scalar_t>.
//
// Parameters:
//   iter      - iterator describing the operands to process.
//   threshold - cut-off, converted to scalar_t before use.
//   value     - replacement value, converted to scalar_t before use.
static void threshold_kernel_cuda(
    TensorIteratorBase& iter,
    const Scalar& threshold,
    const Scalar& value) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "threshold_cuda",
      [&] {
        // Materialize both scalars in the dispatched element type up front.
        const scalar_t typed_threshold = threshold.to<scalar_t>();
        const scalar_t typed_value = value.to<scalar_t>();
        threshold_kernel_impl<scalar_t>(iter, typed_threshold, typed_value);
      });
}
47
48 } // namespace
49
// Hooks threshold_kernel_cuda up to threshold_stub via the dispatch
// registration macro. NOTE(review): presumably this registers the CUDA
// backend implementation of the stub — confirm against DispatchStub.h.
REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda);
51
52 } // namespace at::native
53