#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/PointwiseOps.h>
#include <c10/core/Scalar.h>

namespace at::native {

#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
CONSTEXPR_EXCEPT_WIN_CUDA char addcmul_name[] = "addcmul";
#endif
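// addcmul: out = self + value * tensor1 * tensor2, applied elementwise over the
// operands packed into `iter`. Complex dtypes are JIT-compiled via the jiterator
// when it is available; every other dtype uses the precompiled gpu_kernel path.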
void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
    #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
      AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
        auto alpha = value.to<scalar_t>();
        static const auto addcmul_string = jiterator_stringify(
          template <typename T> T addcmul(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
        jitted_gpu_kernel<
            /*name=*/addcmul_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/3>(
            iter,
            addcmul_string,
            /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
            /*scalar_val=*/0,
            /*extra_args=*/std::make_tuple(alpha));
      });
    #else
      AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
        auto alpha = value.to<scalar_t>();
        gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
          return a + alpha * b * c;
        });
      });
    #endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
      // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
      // and do math in fp32 for better accuracy.
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto alpha = value.to<accscalar_t>();
      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
        return a + alpha * (static_cast<accscalar_t>(b) * static_cast<accscalar_t>(c));
      });
    });
  }
}

#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
// return a + alpha * (b / static_cast<accscalar_t>(c));
CONSTEXPR_EXCEPT_WIN_CUDA char addcdiv_name[] = "addcdiv";
#endif
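// addcdiv: out = self + value * tensor1 / tensor2, applied elementwise.
// Structured the same way as addcmul_cuda_kernel above: jiterator for complex
// dtypes when it is available, precompiled gpu_kernel otherwise.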
void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
    #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
      AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_cuda", [&]() {
        auto alpha = value.to<scalar_t>();
        static const auto addcdiv_string =
            jiterator_stringify(template <typename T> T addcdiv(
                T a, T b, T c, T alpha) { return a + alpha * (b / c); });
        jitted_gpu_kernel<
            /*name=*/addcdiv_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/3>(
            iter,
            addcdiv_string,
            /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
            /*scalar_val=*/0,
            /*extra_args=*/std::make_tuple(alpha));
      });
    #else
      AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_cuda", [&]() {
        auto alpha = value.to<scalar_t>();
        gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
          return a + alpha * (b / c);
        });
      });
    #endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcdiv_cuda", [&]() {
      // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
      // and do math in fp32 for better accuracy.
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto alpha = value.to<accscalar_t>();
      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
        return a + alpha * (b / static_cast<accscalar_t>(c));
      });
    });
  }
}

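// Backward of smooth L1 loss: for x = input - target, the gradient w.r.t. the
// input is norm * grad_output * clamp(x / beta, -1, 1), i.e. +/-norm * grad_output
// outside the [-beta, beta] band and norm * x * grad_output / beta inside it.
// `norm` carries the caller-supplied scaling (e.g. from the reduction).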
void smooth_l1_backward_cuda_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "smooth_l1_backward_cuda", [&iter, &norm, beta] {
    auto norm_val = norm.to<scalar_t>();
    scalar_t beta_val(beta);
    gpu_kernel(iter, [norm_val, beta_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
      const auto x = input - target;
      if (x < -beta_val)
        return -norm_val * grad_output;
      else if (x > beta_val)
        return norm_val * grad_output;
      else
        return norm_val * x * grad_output / beta_val;
    });
  });
}

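// Backward of Huber loss with threshold `delta`: for x = input - target, the
// gradient w.r.t. the input is norm * x * grad_output in the quadratic region
// (|x| <= delta) and +/-norm * delta * grad_output in the linear region.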
void huber_backward_cuda_kernel(TensorIterator& iter, const Scalar& norm, double delta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "huber_backward_cuda", [&iter, &norm, delta] {
    auto norm_val = norm.to<scalar_t>();
    scalar_t delta_val(delta);
    gpu_kernel(iter, [norm_val, delta_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
      const auto x = input - target;
      if (x < -delta_val) {
        return -norm_val * grad_output * delta_val;
      } else if (x > delta_val) {
        return norm_val * grad_output * delta_val;
      } else {
        return norm_val * x * grad_output;
      }
    });
  });
}

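// Backward of MSE loss: gradient w.r.t. the input is value * (input - target) *
// grad_output, with `value` carrying the constant factor supplied by the caller
// (typically 2 from d/dx of x^2, plus any reduction scaling).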
void mse_backward_cuda_kernel(TensorIterator& iter, const Scalar& value) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "mse_backward_cuda", [&]() {
    auto alpha = value.to<scalar_t>();
    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
      return alpha * (a - b) * c;
    });
  });
}

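// Hook the CUDA implementations above into the device-generic dispatch stubs
// declared in the PointwiseOps.h header included at the top of this file.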
REGISTER_DISPATCH(addcdiv_stub, &addcdiv_cuda_kernel);
REGISTER_DISPATCH(addcmul_stub, &addcmul_cuda_kernel);
REGISTER_DISPATCH(smooth_l1_backward_stub, &smooth_l1_backward_cuda_kernel);
REGISTER_DISPATCH(huber_backward_stub, &huber_backward_cuda_kernel);
REGISTER_DISPATCH(mse_backward_stub, &mse_backward_cuda_kernel);
} // namespace at::native