#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/BinaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
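//
// For illustration (a hypothetical snippet, not code in this file), a
// definition like
//
//   static void bad(TensorIteratorBase& iter) {  // internal linkage
//     gpu_kernel(iter, [] GPU_LAMBDA(float a) -> float { return a; });
//   }
//
// would run afoul of that restriction, which is why the kernel entry points
// below are non-static, namespace-scope functions.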

namespace at::native {

CONSTEXPR_EXCEPT_WIN_CUDA char sigmoid_backward_name[] = "sigmoid_backward";
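// Math note: for y = sigmoid(x), dy/dx = y * (1 - y), so the kernel computes
// grad_input = grad_output * (1 - y) * y. For complex dtypes, multiplying by
// conj(dy/dx) follows autograd's Wirtinger-calculus convention.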
void sigmoid_backward_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (isComplexType(dtype)) {
#if AT_USE_JITERATOR()
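    // Jiterator path: the kernel body below travels as a source string and is
    // compiled at runtime (via NVRTC) the first time a complex dtype reaches
    // this kernel, keeping complex support out of the ahead-of-time binary.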
    static const auto sigmoid_backward_string = jiterator_stringify(
        template <typename T>
        T sigmoid_backward(T a, T b) {
          return a * std::conj((T{1.} - b) * b);
        }
    ); // sigmoid_backward_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "sigmoid_backward_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/ sigmoid_backward_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(iter, sigmoid_backward_string);
    });
#else
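    // Precompiled fallback: do the arithmetic in opmath_type (e.g.
    // complex<float> for complex<Half>) so intermediates keep full precision,
    // then cast the result back to scalar_t.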
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "sigmoid_backward_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        using comp_t = at::opmath_type<scalar_t>;
        const auto one = comp_t{1.};
        const auto comp_b = static_cast<comp_t>(b);
        const auto comp_a = static_cast<comp_t>(a);
        return static_cast<scalar_t>(comp_a * std::conj((one - comp_b) * comp_b));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "sigmoid_backward_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a * (scalar_t(1.) - b) * b;
      });
    });
  }
}

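// Math note: logit(x) = log(x / (1 - x)), so d/dx logit(x) = 1 / (x * (1 - x))
// and grad_input = grad_output / (x * (1 - x)). The eps argument mirrors the
// forward clamp: a negative eps means no clamping was applied, so inputs
// outside [0, 1] yield NaN gradients; otherwise inputs outside [eps, 1 - eps]
// were clamped in the forward pass and yield zero gradients. The math runs in
// acc_type (float when scalar_t is Half or BFloat16) for precision.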
void logit_backward_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "logit_cuda",
      [&]() {
        using T_ACC = acc_type<scalar_t, true>;
        const T_ACC eps = eps_scalar.to<T_ACC>();
        if (eps < T_ACC(0)) {
          gpu_kernel(
              iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
                const T_ACC dy_acc = static_cast<T_ACC>(dy);
                const T_ACC x_acc = static_cast<T_ACC>(x);
                return (x_acc < T_ACC(0) || x_acc > T_ACC(1))
                    ? std::numeric_limits<T_ACC>::quiet_NaN()
                    : dy_acc / (x_acc * (T_ACC(1) - x_acc));
              });
        } else {
          const T_ACC lo = eps;
          const T_ACC hi = T_ACC(1) - eps;
          gpu_kernel(
              iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
                const T_ACC dy_acc = static_cast<T_ACC>(dy);
                const T_ACC x_acc = static_cast<T_ACC>(x);
                return (x_acc < lo || x_acc > hi)
                    ? T_ACC(0)
                    : dy_acc / (x_acc * (T_ACC(1) - x_acc));
              });
        }
      });
}

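// Math note: for y = tanh(x), dy/dx = 1 - y^2, so grad_input =
// grad_output * (1 - y^2), again using conj(dy/dx) for complex dtypes per the
// Wirtinger convention.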
CONSTEXPR_EXCEPT_WIN_CUDA char tanh_backward_name[] = "tanh_backward";
void tanh_backward_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (isComplexType(dtype)) {
#if AT_USE_JITERATOR()
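    // Same jiterator / precompiled split as sigmoid_backward_kernel_cuda above.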
    static const auto tanh_backward_string = jiterator_stringify(
        template <typename T>
        T tanh_backward(T a, T b) {
          return a * std::conj(T{1.} - b * b);
        }
    ); // tanh_backward_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/ tanh_backward_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(iter, tanh_backward_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        using comp_t = at::opmath_type<scalar_t>;
        const auto one = comp_t{1.};
        const auto comp_b = static_cast<comp_t>(b);
        const auto comp_a = static_cast<comp_t>(a);
        return static_cast<scalar_t>(comp_a * std::conj(one - comp_b * comp_b));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "tanh_backward_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a * (scalar_t{1.} - b * b);
      });
    });
  }
}

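// Register the CUDA kernels with the device-generic dispatch stubs (declared
// in ATen/native/BinaryOps.h) so backward calls on CUDA tensors route here.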
REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel_cuda);
REGISTER_DISPATCH(logit_backward_stub, &logit_backward_kernel_cuda);
REGISTER_DISPATCH(tanh_backward_stub, &tanh_backward_kernel_cuda);

} // namespace at::native