// aten/src/ATen/native/cuda/StepKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <c10/util/BFloat16-math.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
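// For example, a kernel hypothetically declared `static`, or defined inside an
// anonymous namespace, would have internal linkage; since each kernel below
// encloses a GPU_LAMBDA (__device__ lambda), such a declaration would break
// the Windows CUDA build.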

namespace at::native {

// Elementwise torch.nextafter: for each pair (a, b), returns the next
// representable floating-point value after a in the direction of b.
void nextafter_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.common_dtype(), "nextafter_cuda", [&]() {
    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
      return std::nextafter(a, b);
    });
  });
}
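
// Worked example (float values, for illustration only):
//   nextafter(1.0f, 2.0f)  ->  1.0f + FLT_EPSILON (the smallest float above 1.0f)
//   nextafter(1.0f, 0.0f)  ->  the largest float below 1.0f
//   nextafter(1.0f, 1.0f)  ->  1.0f
// The c10/util/BFloat16-math.h include above supplies the std::nextafter
// overload used when scalar_t is c10::BFloat16.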

// Elementwise torch.heaviside: 0 where a < 0, 1 where a > 0, and the value of
// the second operand b where a == 0.
void heaviside_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() {
    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
      return a == 0 ? b : static_cast<scalar_t>(a > 0);
    });
  });
}
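
// Worked example (values for illustration only): heaviside(a, b) gives
//   a = -1.5, b = 0.5  ->  0
//   a =  0.0, b = 0.5  ->  0.5   (b supplies the value at zero)
//   a =  2.0, b = 0.5  ->  1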

REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
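
// These registrations bind the CUDA implementations to the device-agnostic
// stubs declared in ATen/native/BinaryOps.h, roughly (sketch for illustration,
// not the verbatim declarations):
//
//   DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
//   DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
//
// so a call such as nextafter_stub(iter.device_type(), iter) dispatches to
// nextafter_kernel_cuda for CUDA tensors.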

} // namespace at::native