#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/Lerp.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/OpMathType.h>

namespace at::native {
namespace {

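// CUDA kernels for torch.lerp with a tensor weight and with a scalar weight.
// Complex dtypes are JIT-compiled through the jiterator when AT_USE_JITERATOR()
// is enabled; otherwise (and for floating-point dtypes) a precompiled
// gpu_kernel lambda is used. The name below is the device function defined
// inside the jiterator_stringify() source, which jitted_gpu_kernel looks up
// when it compiles and caches the generated kernel.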
CONSTEXPR_EXCEPT_WIN_CUDA char lerp_tensor_name[] = "lerp_tensor";
void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
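  // Stringified device code handed to the jiterator (compiled at runtime with
  // NVRTC). The (std::abs(weight) < 0.5) guard selects the formulation anchored
  // at the nearer endpoint, which keeps the interpolation numerically well
  // behaved for weights close to 0 or 1.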
  static const auto lerp_tensor_string = jiterator_stringify(
      template <typename T>
      T lerp_tensor(T self_val, T end_val, T weight_val) {
        return (std::abs(weight_val) < 0.5)
            ? self_val + weight_val * (end_val - self_val)
            : end_val -
                (end_val - self_val) * (static_cast<T>(1) - weight_val);
      }
  ); // lerp_tensor_string
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
        jitted_gpu_kernel<
          /*name=*/ lerp_tensor_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 3>(iter, lerp_tensor_string);
      });
#else
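  // Jiterator disabled at build time: fall back to a precompiled gpu_kernel.
  // Inputs are widened to opmath_t (e.g. complex<float> for complex<Half>) so
  // the interpolation is computed at higher precision before narrowing back to
  // scalar_t.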
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
      using opmath_t = at::opmath_type<scalar_t>;
      at::native::gpu_kernel(
        iter,
        [] GPU_LAMBDA(
            scalar_t self_val,
            scalar_t end_val,
            scalar_t weight_val) -> scalar_t {
          opmath_t self_val_f = self_val;
          opmath_t end_val_f = end_val;
          opmath_t weight_val_f = weight_val;
          return lerp(self_val_f, end_val_f, weight_val_f);
        });
      });
#endif
  } else {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      dtype, "lerp_cuda",
      [&] {
        at::native::gpu_kernel(
            iter,
            [] GPU_LAMBDA(
                scalar_t self_val,
                scalar_t end_val,
                scalar_t weight_val) -> scalar_t {
              return lerp(self_val, end_val, weight_val);
            });
      });
  }
}

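// Scalar-weight variant: the weight is converted once on the host to opmath_t.
// The jitted path forwards it to the generated kernel through extra_args (so
// the kernel arity is 2), while the fallback lambdas capture it by value.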
CONSTEXPR_EXCEPT_WIN_CUDA char lerp_scalar_name[] = "lerp_scalar";
void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
  static const auto lerp_scalar_string = jiterator_stringify(
      template <typename T>
      T lerp_scalar(T self_val, T end_val, T weight_val) {
        return (std::abs(weight_val) < 0.5)
            ? self_val + weight_val * (end_val - self_val)
            : end_val -
                (end_val - self_val) * (static_cast<T>(1) - weight_val);
      }
  ); // lerp_scalar_string
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
      using opmath_t = at::opmath_type<scalar_t>;
      auto weight_val = weight.to<opmath_t>();
      jitted_gpu_kernel<
        /*name=*/ lerp_scalar_name,
        /*return_dtype=*/ scalar_t,
        /*common_dtype=*/ scalar_t,
        /*arity=*/ 2>(
        iter,
        lerp_scalar_string,
        /*scalar_pos=*/ at::cuda::jit::BinaryFuncVariant::NoScalar,
        /*scalar_val=*/ 0,
        /*extra_args=*/ std::make_tuple(weight_val));
  });
#else
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    auto weight_val = weight.to<opmath_t>();
    at::native::gpu_kernel(
        iter,
        [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) {
          opmath_t self_val_f = self_val;
          opmath_t end_val_f = end_val;
          return lerp(self_val_f, end_val_f, weight_val);
        });
  });
#endif
  } else {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      dtype, "lerp_cuda",
      [&]{
        using opmath_t = at::opmath_type<scalar_t>;
        auto weight_val = weight.to<opmath_t>();
        at::native::gpu_kernel(
            iter, [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) {
              return lerp(self_val, end_val, weight_val);
            });
      });
  }
}

} // anonymous namespace

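// Register the CUDA kernels with the dispatch stubs declared in
// ATen/native/Lerp.h. A device-agnostic caller reaches them through the stub,
// roughly as in this hedged sketch (the real call sites live in the op
// implementations, e.g. aten/src/ATen/native/Lerp.cpp, and the variable names
// here are illustrative only):
//
//   at::TensorIteratorBase& iter = /* iterator over self, end, out */;
//   lerp_kernel_scalar_weight(iter.device_type(), iter, weight);  // dispatches here on CUDA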
REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_kernel);
REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_kernel);

} // namespace at::native