#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/BinaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {

CONSTEXPR_EXCEPT_WIN_CUDA char sigmoid_backward_name[] = "sigmoid_backward";
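// sigmoid_backward(grad_output a, output b): for y = sigmoid(x) the local
// derivative is dy/dx = y * (1 - y), so grad_input = a * (1 - b) * b. For
// complex dtypes the derivative is conjugated, matching PyTorch's
// Wirtinger-calculus convention for complex autograd.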
void sigmoid_backward_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (isComplexType(dtype)) {
#if AT_USE_JITERATOR()
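    // Jiterator path: the kernel body is carried as a source string and
    // JIT-compiled at runtime (via NVRTC) on first use, keeping the complex
    // specializations out of the statically compiled binary.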
    static const auto sigmoid_backward_string = jiterator_stringify(
        template <typename T>
        T sigmoid_backward(T a, T b) {
          return a * std::conj((T{1.} - b) * b);
        }
    ); // sigmoid_backward_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "sigmoid_backward_cuda", [&]() {
        jitted_gpu_kernel<
          /*name=*/ sigmoid_backward_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(iter, sigmoid_backward_string);
    });
#else
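    // Non-jiterator fallback: compute in opmath_type (e.g. complex<float>
    // when scalar_t is complex<Half>) to limit intermediate precision loss.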
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "sigmoid_backward_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        using comp_t = at::opmath_type<scalar_t>;
        const auto one = comp_t{1.};
        const auto comp_b = static_cast<comp_t>(b);
        const auto comp_a = static_cast<comp_t>(a);
        return static_cast<scalar_t>(comp_a * std::conj((one - comp_b) * comp_b));
      });
    });
#endif
  } else {
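    // Real dtypes: grad_input = grad_output * (1 - y) * y; no conjugation needed.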
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "sigmoid_backward_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a * (scalar_t(1.) - b) * b;
      });
    });
  }
}

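// logit_backward(grad_output dy, input x): logit(x) = log(x / (1 - x)), so
// d/dx logit(x) = 1 / (x * (1 - x)). A negative eps indicates the forward did
// no clamping; a non-negative eps indicates the forward clamped x to
// [eps, 1 - eps].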
void logit_backward_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "logit_cuda",
      [&]() {
        using T_ACC = acc_type<scalar_t, true>;
        const T_ACC eps = eps_scalar.to<T_ACC>();
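        // Unclamped case: inputs outside [0, 1] are outside logit's domain,
        // so their gradient is NaN.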
        if (eps < T_ACC(0)) {
          gpu_kernel(
              iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
                const T_ACC dy_acc = static_cast<T_ACC>(dy);
                const T_ACC x_acc = static_cast<T_ACC>(x);
                return (x_acc < T_ACC(0) || x_acc > T_ACC(1))
                    ? std::numeric_limits<T_ACC>::quiet_NaN()
                    : dy_acc / (x_acc * (T_ACC(1) - x_acc));
              });
        } else {
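          // Clamped case: the clamp has zero derivative, so gradients for
          // inputs outside [lo, hi] are zeroed instead of NaN.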
          const T_ACC lo = eps;
          const T_ACC hi = T_ACC(1) - eps;
          gpu_kernel(
              iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
                const T_ACC dy_acc = static_cast<T_ACC>(dy);
                const T_ACC x_acc = static_cast<T_ACC>(x);
                return (x_acc < lo || x_acc > hi)
                    ? T_ACC(0)
                    : dy_acc / (x_acc * (T_ACC(1) - x_acc));
              });
        }
      });
}

CONSTEXPR_EXCEPT_WIN_CUDA char tanh_backward_name[] = "tanh_backward";
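// tanh_backward(grad_output a, output b): for y = tanh(x) the local
// derivative is dy/dx = 1 - y^2, so grad_input = a * (1 - b * b), with the
// complex path conjugating the derivative as in sigmoid_backward.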
void tanh_backward_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto tanh_backward_string = jiterator_stringify(
      template <typename T>
      T tanh_backward(T a, T b) {
        return a * std::conj(T{1.} - b * b);
      }
    ); // tanh_backward_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/ tanh_backward_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(iter, tanh_backward_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        using comp_t = at::opmath_type<scalar_t>;
        const auto one = comp_t{1.};
        const auto comp_b = static_cast<comp_t>(b);
        const auto comp_a = static_cast<comp_t>(a);
        return static_cast<scalar_t>(comp_a * std::conj(one - comp_b * comp_b));
      });
    });
#endif
  } else {
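    // Real dtypes: grad_input = grad_output * (1 - y * y).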
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "tanh_backward_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a * (scalar_t{1.} - b * b);
      });
    });
  }
}

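// Hook the CUDA kernels into the device-agnostic dispatch stubs declared in
// ATen/native/BinaryOps.h.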
REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel_cuda);
REGISTER_DISPATCH(logit_backward_stub, &logit_backward_kernel_cuda);
REGISTER_DISPATCH(tanh_backward_stub, &tanh_backward_kernel_cuda);

} // namespace at::native