1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/BinaryOps.h>
6 #include <ATen/native/TensorIterator.h>
7 #include <c10/util/TypeSafeSignMath.h>
8
9 #include <type_traits>
10
11 // NOTE: CUDA on Windows requires that the enclosing function
12 // of a __device__ lambda not have internal linkage.
13
namespace at::native {

// Python-style remainder (torch.remainder): the result is either zero or has
// the same sign as the divisor `b`. It is derived from the truncated remainder
// (C++ `%` for integers, `::fmod` for floating types), whose sign follows the
// dividend `a`, by adding `b` once whenever that remainder is non-zero and its
// sign differs from the sign of `b`.
//
// Example: remainder(-3, 2) == 1 and remainder(3, -2) == -1.
void remainder_kernel_cuda(TensorIteratorBase& iter) {
  if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // x mod -1 is 0 for every x. Returning early also avoids evaluating
        // `min % -1` for the most negative signed value: that quotient is not
        // representable, which makes `%` undefined behavior in C++.
        if constexpr (std::is_signed_v<scalar_t>) {
          if (b == static_cast<scalar_t>(-1)) {
            return 0;
          }
        }
        // NOTE(review): b == 0 is not rejected on the device; integer division
        // by zero is undefined in C++, so that element's result is unspecified.
        scalar_t r = a % b;
        // `%` truncates toward zero (sign of `a`); shift into the sign of `b`.
        if (r != 0 && c10::signs_differ(r, b)) {
          r += b;
        }
        return r;
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "remainder_cuda", [&]() {
      gpu_kernel_with_scalars(iter,
        []GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
          // fmod(a, 0) yields NaN; the __ubsan_ignore_float_divide_by_zero__
          // annotation (as its name indicates) opts this lambda out of UBSan's
          // float divide-by-zero check for that case.
          auto mod = ::fmod(a, b);
          // A zero result is returned as produced by fmod, so it may be -0.0
          // for a negative dividend (`-0.0 != 0` is false).
          if (mod != 0 && c10::signs_differ(b, mod)) {
            mod += b;
          }
          // Converted back to scalar_t by the lambda's declared return type.
          return mod;
        });
    });
  }
}

// C-style remainder (torch.fmod): the truncated remainder, which is either zero
// or has the same sign as the dividend `a`.
//
// Example: fmod(-3, 2) == -1 and fmod(3, -2) == 1.
void fmod_kernel_cuda(TensorIteratorBase& iter) {
  if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // x % -1 is 0 for every x. Returning early avoids `min % -1` for the
        // most negative signed value, which is undefined behavior in C++
        // because the quotient is not representable.
        if constexpr (std::is_signed_v<scalar_t>) {
          if (b == static_cast<scalar_t>(-1)) {
            return 0;
          }
        }
        // NOTE(review): b == 0 is not rejected on the device; integer division
        // by zero is undefined in C++, so that element's result is unspecified.
        return a % b;
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "fmod_cuda", [&]() {
      gpu_kernel_with_scalars(iter,
        []GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
          // fmod(a, 0) yields NaN; the annotation keeps UBSan from flagging it.
          return ::fmod(a, b);
        });
    });
  }
}

// Bind the CUDA implementations above to their dispatch stubs. remainder_stub
// and fmod_stub are declared outside this file (presumably in
// ATen/native/BinaryOps.h — confirm).
REGISTER_DISPATCH(remainder_stub, &remainder_kernel_cuda);
REGISTER_DISPATCH(fmod_stub, &fmod_kernel_cuda);

} // namespace at::native
62