1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/BinaryOps.h>
7 #include <ATen/native/cuda/Math.cuh>
8 #include <ATen/NumericUtils.h>
9
10 // NOTE: CUDA on Windows requires that the enclosing function
11 // of a __device__ lambda not have internal linkage.
12
13 namespace at::native {
14
// Elementwise smooth L1 loss on CUDA for float, double, Half and BFloat16.
// For each element pair (a, b), with z = |a - b|:
//   z <  beta  ->  0.5 * z * z / beta
//   z >= beta  ->  z - 0.5 * beta
// `beta` is converted to scalar_t once per dispatch and captured by value in
// the device lambda.
// With beta == 0 the quadratic branch is never taken (|a - b| < 0 is false),
// so this kernel itself never divides by zero.
// NOTE(review): negative beta is presumably rejected by the caller — confirm.
void smooth_l1_kernel_cuda(TensorIteratorBase& iter, double beta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "smooth_l1_cuda", [&]() {
    scalar_t threshold(beta);
    gpu_kernel(iter, [threshold] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      // Deliberately `auto`, not scalar_t: for reduced-precision types the
      // result of ::abs may be wider, and narrowing here would add a rounding.
      auto abs_diff = ::abs(lhs - rhs);
      if (abs_diff < threshold) {
        return scalar_t(0.5) * abs_diff * abs_diff / threshold;
      }
      return abs_diff - scalar_t(0.5) * threshold;
    });
  });
}
24
// Elementwise Huber loss on CUDA for float, double, Half and BFloat16.
// For each element pair (a, b), with z = |a - b|:
//   z <  delta  ->  0.5 * z * z
//   z >= delta  ->  delta * (z - 0.5 * delta)
// `delta` is converted to scalar_t once per dispatch and captured by value in
// the device lambda.
// NOTE(review): unlike the neighbouring kernels this takes TensorIterator&
// rather than TensorIteratorBase&; presumably it must match huber_stub's
// declared function type, so keep the signature in sync with the stub.
void huber_kernel_cuda(TensorIterator& iter, double delta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "huber_cuda", [&] {
    scalar_t threshold(delta);
    gpu_kernel(iter, [threshold] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      // Deliberately `auto`, not scalar_t: for reduced-precision types the
      // result of ::abs may be wider, and narrowing here would add a rounding.
      auto abs_diff = ::abs(lhs - rhs);
      if (abs_diff < threshold) {
        return scalar_t(0.5) * abs_diff * abs_diff;
      }
      return threshold * (abs_diff - scalar_t(0.5) * threshold);
    });
  });
}
34
// Elementwise squared error (a - b)^2 on CUDA for float, double, Half and
// BFloat16. Only per-element values are produced here; no reduction happens
// in this kernel.
void mse_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "mse_cuda", [&iter]() {
    gpu_kernel(iter, [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      // Named once so the subtraction is evaluated a single time.
      auto residual = lhs - rhs;
      return residual * residual;
    });
  });
}
43
// Elementwise xlogy(x, y) on CUDA for float, double, Half and BFloat16:
//   NaN         if y is NaN  (tested first, so xlogy(0, NaN) is NaN)
//   0           if x == 0    (so xlogy(0, 0) and xlogy(0, inf) are 0)
//   x * log(y)  otherwise
// Dispatches on iter.common_dtype() and launches through
// gpu_kernel_with_scalars, presumably so one operand may be a CPU scalar —
// confirm against Loops.cuh.
void xlogy_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&iter]() {
    gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t x, scalar_t y) -> scalar_t {
      if (!at::_isnan(y)) {
        // x == 0 short-circuits before log(y), avoiding 0 * -inf = NaN.
        if (x == 0) {
          return 0;
        }
        return x * std::log(y);
      }
      return NAN;
    });
  });
}
57
// Elementwise xlog1py(x, y) on CUDA for float, double, Half and BFloat16:
//   NaN           if y is NaN  (tested first, so xlog1py(0, NaN) is NaN)
//   0             if x == 0    (so xlog1py(0, -1) is 0)
//   x * log1p(y)  otherwise
// Dispatches on iter.common_dtype() and launches through
// gpu_kernel_with_scalars, presumably so one operand may be a CPU scalar —
// confirm against Loops.cuh.
void xlog1py_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlog1py_cuda", [&iter]() {
    gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t x, scalar_t y) -> scalar_t {
      if (!at::_isnan(y)) {
        // x == 0 short-circuits before log1p(y), avoiding 0 * -inf = NaN.
        if (x == 0) {
          return 0;
        }
        return x * std::log1p(y);
      }
      return NAN;
    });
  });
}
71
// Bind each CUDA kernel defined above to its dispatch stub (stubs presumably
// declared in ATen/native/BinaryOps.h). Every line targets a different stub,
// so the order of these registrations is not significant.
REGISTER_DISPATCH(huber_stub, &huber_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel_cuda);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);
77
78 // DO NOT ADD ANY NEW KERNELS HERE
79 // CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.
80
81 } // namespace at::native
82