#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/Lerp.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/OpMathType.h>
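
// CUDA implementations of torch.lerp, i.e. out = self + weight * (end - self),
// with the weight supplied either as a tensor or as a scalar.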

namespace at::native {
namespace {
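
// Linear interpolation with a tensor-valued weight. Complex dtypes are JIT
// compiled through the jiterator when it is available; real floating dtypes
// always use a statically compiled kernel.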
CONSTEXPR_EXCEPT_WIN_CUDA char lerp_tensor_name[] = "lerp_tensor";
void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
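    // This source string is compiled with NVRTC on first use. The branch on
    // |weight| < 0.5 mirrors the numerically guarded lerp in
    // ATen/native/Lerp.h, which avoids cancellation when weight is near 1.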
    static const auto lerp_tensor_string = jiterator_stringify(
        template <typename T>
        T lerp_tensor(T self_val, T end_val, T weight_val) {
          return (std::abs(weight_val) < 0.5)
              ? self_val + weight_val * (end_val - self_val)
              : end_val -
                  (end_val - self_val) * (static_cast<T>(1) - weight_val);
        }
    ); // lerp_tensor_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
      jitted_gpu_kernel<
          /*name=*/ lerp_tensor_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 3>(iter, lerp_tensor_string);
    });
#else
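    // Fallback when the jiterator is unavailable: a statically compiled
    // kernel that promotes to opmath_t so that, e.g., complex<Half> inputs
    // are interpolated at complex<float> precision.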
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
      using opmath_t = at::opmath_type<scalar_t>;
      at::native::gpu_kernel(
          iter,
          [] GPU_LAMBDA(
              scalar_t self_val,
              scalar_t end_val,
              scalar_t weight_val) -> scalar_t {
            opmath_t self_val_f = self_val;
            opmath_t end_val_f = end_val;
            opmath_t weight_val_f = weight_val;
            // Interpolate on the promoted values; the lambda's scalar_t
            // return type narrows the result back down.
            return lerp(self_val_f, end_val_f, weight_val_f);
          });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16,
        dtype, "lerp_cuda",
        [&] {
          at::native::gpu_kernel(
              iter,
              [] GPU_LAMBDA(
                  scalar_t self_val,
                  scalar_t end_val,
                  scalar_t weight_val) -> scalar_t {
                return lerp(self_val, end_val, weight_val);
              });
        });
  }
}

CONSTEXPR_EXCEPT_WIN_CUDA char lerp_scalar_name[] = "lerp_scalar";
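// Linear interpolation with a scalar weight. The weight is read once on the
// host, so the iterator carries only self and end; the kernel receives the
// weight as an extra argument (jiterator) or by lambda capture.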
void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto lerp_scalar_string = jiterator_stringify(
        template <typename T>
        T lerp_scalar(T self_val, T end_val, T weight_val) {
          return (std::abs(weight_val) < 0.5)
              ? self_val + weight_val * (end_val - self_val)
              : end_val -
                  (end_val - self_val) * (static_cast<T>(1) - weight_val);
        }
    ); // lerp_scalar_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
      using opmath_t = at::opmath_type<scalar_t>;
      auto weight_val = weight.to<opmath_t>();
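      // arity is 2 because only self and end come from the iterator; the
      // weight reaches the jitted kernel through extra_args. The scalar_pos
      // and scalar_val slots are unused placeholders here.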
      jitted_gpu_kernel<
          /*name=*/ lerp_scalar_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(
          iter,
          lerp_scalar_string,
          /*scalar_pos=*/ at::cuda::jit::BinaryFuncVariant::NoScalar,
          /*scalar_val=*/ 0,
          /*extra_args=*/ std::make_tuple(weight_val));
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
      using opmath_t = at::opmath_type<scalar_t>;
      auto weight_val = weight.to<opmath_t>();
      at::native::gpu_kernel(
          iter,
          [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) -> scalar_t {
            // Promote the tensor operands to opmath_t to match weight_val
            // before interpolating.
            opmath_t self_val_f = self_val;
            opmath_t end_val_f = end_val;
            return lerp(self_val_f, end_val_f, weight_val);
          });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16,
        dtype, "lerp_cuda",
        [&] {
          using opmath_t = at::opmath_type<scalar_t>;
          auto weight_val = weight.to<opmath_t>();
          at::native::gpu_kernel(
              iter,
              [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) -> scalar_t {
                return lerp(self_val, end_val, weight_val);
              });
        });
  }
}

} // anonymous namespace
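
// Register these as the CUDA implementations of the lerp dispatch stubs
// (declared next to the shared lerp helpers in ATen/native/Lerp.h).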
REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_kernel);
REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_kernel);

} // namespace at::native