#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/PointwiseOps.h>
#include <c10/core/Scalar.h>

namespace at::native {

#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
CONSTEXPR_EXCEPT_WIN_CUDA char addcmul_name[] = "addcmul";
#endif
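// Computes out = self + value * tensor1 * tensor2 elementwise. Complex dtypes
// take the Jiterator path (when available and CUDA >= 11.5); all other dtypes
// use gpu_kernel, accumulating in fp32 when scalar_t is fp16 or bf16.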
void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
      auto alpha = value.to<scalar_t>();
      static const auto addcmul_string = jiterator_stringify(
          template <typename T> T addcmul(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
      jitted_gpu_kernel<
          /*name=*/addcmul_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/3>(
          iter,
          addcmul_string,
          /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
          /*scalar_val=*/0,
          /*extra_args=*/std::make_tuple(alpha));
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
      auto alpha = value.to<scalar_t>();
      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
        return a + alpha * b * c;
      });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
      // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
      // and do math in fp32 for better accuracy.
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto alpha = value.to<accscalar_t>();
      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
        return a + alpha * (static_cast<accscalar_t>(b) * static_cast<accscalar_t>(c));
      });
    });
  }
}

#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
// return a + alpha * (b / static_cast<accscalar_t>(c));
CONSTEXPR_EXCEPT_WIN_CUDA char addcdiv_name[] = "addcdiv";
#endif
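// Computes out = self + value * tensor1 / tensor2 elementwise, mirroring the
// addcmul kernel above: Jiterator for complex dtypes (CUDA >= 11.5), otherwise
// gpu_kernel with fp32 accumulation for fp16/bf16 inputs.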
void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_cuda", [&]() {
      auto alpha = value.to<scalar_t>();
      static const auto addcdiv_string =
          jiterator_stringify(template <typename T> T addcdiv(
              T a, T b, T c, T alpha) { return a + alpha * (b / c); });
      jitted_gpu_kernel<
          /*name=*/addcdiv_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/3>(
          iter,
          addcdiv_string,
          /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
          /*scalar_val=*/0,
          /*extra_args=*/std::make_tuple(alpha));
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_cuda", [&]() {
      auto alpha = value.to<scalar_t>();
      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
        return a + alpha * (b / c);
      });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcdiv_cuda", [&]() {
      // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
      // and do math in fp32 for better accuracy.
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto alpha = value.to<accscalar_t>();
      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
        return a + alpha * (b / static_cast<accscalar_t>(c));
      });
    });
  }
}

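// Backward of smooth L1 loss. With x = input - target, the gradient is
// -norm * grad_output for x < -beta, +norm * grad_output for x > beta, and
// norm * x * grad_output / beta in the quadratic region in between.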
void smooth_l1_backward_cuda_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "smooth_l1_backward_cuda", [&iter, &norm, beta] {
    auto norm_val = norm.to<scalar_t>();
    scalar_t beta_val(beta);
    gpu_kernel(iter, [norm_val, beta_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
      const auto x = input - target;
      if (x < -beta_val)
        return -norm_val * grad_output;
      else if (x > beta_val)
        return norm_val * grad_output;
      else
        return norm_val * x * grad_output / beta_val;
    });
  });
}

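// Backward of Huber loss. With x = input - target, the gradient saturates at
// +/- norm * grad_output * delta outside [-delta, delta] and is
// norm * x * grad_output inside it.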
void huber_backward_cuda_kernel(TensorIterator& iter, const Scalar& norm, double delta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "huber_backward_cuda", [&iter, &norm, delta] {
    auto norm_val = norm.to<scalar_t>();
    scalar_t delta_val(delta);
    gpu_kernel(iter, [norm_val, delta_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
      const auto x = input - target;
      if (x < -delta_val) {
        return -norm_val * grad_output * delta_val;
      } else if (x > delta_val) {
        return norm_val * grad_output * delta_val;
      } else {
        return norm_val * x * grad_output;
      }
    });
  });
}

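// Backward of MSE loss: grad_input = value * (input - target) * grad_output,
// with the scaling factor supplied by the caller through `value`.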
void mse_backward_cuda_kernel(TensorIterator& iter, const Scalar& value) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "mse_backward_cuda", [&]() {
    auto alpha = value.to<scalar_t>();
    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
      return alpha * (a - b) * c;
    });
  });
}

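// Hook the CUDA implementations above into their corresponding dispatch stubs.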
REGISTER_DISPATCH(addcdiv_stub, &addcdiv_cuda_kernel);
REGISTER_DISPATCH(addcmul_stub, &addcmul_cuda_kernel);
REGISTER_DISPATCH(smooth_l1_backward_stub, &smooth_l1_backward_cuda_kernel);
REGISTER_DISPATCH(huber_backward_stub, &huber_backward_cuda_kernel);
REGISTER_DISPATCH(mse_backward_stub, &mse_backward_cuda_kernel);
} // namespace at::native