xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Ternary and higher-order pointwise operations
2 #define TORCH_ASSERT_NO_OPERATORS
3 #include <ATen/Dispatch.h>
4 #include <ATen/native/PointwiseOps.h>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/cpu/Loops.h>
7 #include <c10/core/Scalar.h>
8 #include <ATen/cpu/vec/functional.h>
9 namespace at::native {
10 namespace {
11 
addcmul_cpu_kernel(TensorIteratorBase & iter,const Scalar & value)12 static void addcmul_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
13   ScalarType dtype = iter.common_dtype();
14   if (at::isReducedFloatingType(dtype)) {
15     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "addcmul_cpu_out", [&]() {
16       float float_val = value.to<float>();
17       auto float_vec = Vectorized<float>(float_val);
18       cpu_kernel_vec(
19           iter,
20           [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
21             return float(self_val) + float_val * float(t1_val) * float(t2_val);
22           },
23           [=](Vectorized<scalar_t> self_vec,
24             Vectorized<scalar_t> t1_vec,
25             Vectorized<scalar_t> t2_vec) -> Vectorized<scalar_t> {
26             auto [self_vec0, self_vec1] = convert_to_float<scalar_t>(self_vec);
27             auto [t1_vec0, t1_vec1] = convert_to_float<scalar_t>(t1_vec);
28             auto [t2_vec0, t2_vec1] = convert_to_float<scalar_t>(t2_vec);
29             self_vec0 = self_vec0 + float_vec * t1_vec0 * t2_vec0;
30             self_vec1 = self_vec1 + float_vec * t1_vec1 * t2_vec1;
31             return convert_from_float<scalar_t>(self_vec0, self_vec1);
32           });
33     });
34   } else {
35     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::ComplexHalf,
36                                            dtype, "addcmul_cpu_out", [&] {
37       scalar_t scalar_val = value.to<scalar_t>();
38       auto scalar_vec = Vectorized<scalar_t>(scalar_val);
39       cpu_kernel_vec(
40           iter,
41           [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
42             return self_val + scalar_val * t1_val * t2_val;
43           },
44           [=](Vectorized<scalar_t> self_vec,
45               Vectorized<scalar_t> t1_vec,
46               Vectorized<scalar_t> t2_vec) {
47             return self_vec + scalar_vec * t1_vec * t2_vec;
48           });
49     });
50   }
51 }
52 
addcdiv_cpu_kernel(TensorIteratorBase & iter,const Scalar & value)53 static void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
54   ScalarType dtype = iter.common_dtype();
55   if (at::isReducedFloatingType(dtype)) {
56     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "addcdiv_cpu_out", [&]() {
57       float float_val = value.to<float>();
58       auto float_vec = Vectorized<float>(float_val);
59       cpu_kernel_vec(
60           iter,
61           [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
62             return float(self_val) + float_val * float(t1_val) / float(t2_val);
63           },
64           [=](Vectorized<scalar_t> self_vec,
65               Vectorized<scalar_t> t1_vec,
66               Vectorized<scalar_t> t2_vec) -> Vectorized<scalar_t> {
67               auto [self_vec0, self_vec1] = convert_to_float<scalar_t>(self_vec);
68               auto [t1_vec0, t1_vec1] = convert_to_float<scalar_t>(t1_vec);
69               auto [t2_vec0, t2_vec1] = convert_to_float<scalar_t>(t2_vec);
70               self_vec0 = self_vec0 + float_vec * t1_vec0 / t2_vec0;
71               self_vec1 = self_vec1 + float_vec * t1_vec1 / t2_vec1;
72               return convert_from_float<scalar_t>(self_vec0, self_vec1);
73           });
74     });
75   } else {
76     AT_DISPATCH_ALL_TYPES_AND_COMPLEX(dtype, "addcdiv_cpu_out", [&] {
77       scalar_t scalar_val = value.to<scalar_t>();
78       auto scalar_vec = Vectorized<scalar_t>(scalar_val);
79       cpu_kernel_vec(
80           iter,
81           [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
82             return self_val + scalar_val * t1_val / t2_val;
83           },
84           [=](Vectorized<scalar_t> self_vec,
85               Vectorized<scalar_t> t1_vec,
86               Vectorized<scalar_t> t2_vec) {
87             return self_vec + scalar_vec * t1_vec / t2_vec;
88           });
89     });
90   }
91 }
92 
// Backward of smooth_l1_loss. With x = input - target, the gradient w.r.t.
// input is norm * grad_output scaled by:
//   -1       if x <= -beta
//    1       if x >=  beta
//    x/beta  otherwise
// `norm` folds in the reduction factor supplied by the caller. The kBFloat16
// branch promotes to float for the arithmetic; other dtypes compute in
// scalar_t. Note the scalar lambdas mix in the double `beta` directly, so
// intermediate comparisons/divisions happen at double precision there.
static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
  ScalarType dtype = iter.dtype(0);
  if (dtype == kBFloat16) {
    auto norm_val = norm.to<float>();
    float beta_val(beta);
    auto norm_val_vec = Vectorized<float>(norm_val);
    auto beta_val_vec = Vectorized<float>(beta_val);
    const auto neg_1_vec = Vectorized<float>(-1);
    const auto zero_vec = Vectorized<float>(0);
    const auto pos_1_vec = Vectorized<float>(1);
    cpu_kernel_vec(iter,
      // Scalar fallback: branch on the three regions explicitly.
      [=](BFloat16 input, BFloat16 target, BFloat16 grad_output) -> BFloat16 {
        const auto x = float(input) - float(target);
        if (x <= -beta){
          return -norm_val * float(grad_output);
        }else if (x >= beta){
          return norm_val * float(grad_output);
        }else{
          return norm_val * x * float(grad_output) / beta;
        }
      },
      [norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
         Vectorized<BFloat16> input, Vectorized<BFloat16> target, Vectorized<BFloat16> grad_output) -> Vectorized<BFloat16> {
        // using two blendv calls to simulate the 3 cases
        // 1        if  x >= beta
        // -1       if x <= -beta
        // x / beta if |x| < beta
        // BFloat16 vectors widen into two float vectors; the same blendv
        // sequence runs on each half independently.
        auto [input0, input1] = convert_bfloat16_float(input);
        auto [target0, target1] = convert_bfloat16_float(target);
        auto [grad_output0, grad_output1] = convert_bfloat16_float(grad_output);
        auto x = input0 - target0;
        // Sign select: +1 where x > 0, else -1 (blendv takes the second arg
        // where the mask is true).
        auto pos_or_neg_1_vec = Vectorized<float>::blendv(
            neg_1_vec, pos_1_vec, x > zero_vec);
        auto x_abs = x.abs();
        // Saturate to +/-1 outside the quadratic region, x/beta inside it.
        auto output = Vectorized<float>::blendv(
            x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
        input0 = norm_val_vec * output * grad_output0;

        // Second half: identical computation on the upper float lanes.
        x = input1 - target1;
        pos_or_neg_1_vec = Vectorized<float>::blendv(
            neg_1_vec, pos_1_vec, x > zero_vec);
        x_abs = x.abs();
        output = Vectorized<float>::blendv(
            x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
        input1 = norm_val_vec * output * grad_output1;
        return convert_float_bfloat16(input0, input1);
      }
    );
  } else {
    AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
    auto norm_val = norm.to<scalar_t>();
    scalar_t beta_val(beta);
    auto norm_val_vec = Vectorized<scalar_t>(norm_val);
    auto beta_val_vec = Vectorized<scalar_t>(beta_val);
    const auto neg_1_vec = Vectorized<scalar_t>(-1);
    const auto zero_vec = Vectorized<scalar_t>(0);
    const auto pos_1_vec = Vectorized<scalar_t>(1);
    cpu_kernel_vec(iter,
      // Scalar fallback: same three-region branch in scalar_t (comparisons
      // against the double `beta` promote for integral/float scalar_t).
      [=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
        const auto x = input - target;
        if (x <= -beta)
          return -norm_val * grad_output;
        else if (x >= beta)
          return norm_val * grad_output;
        else
          return norm_val * x * grad_output / beta;
      },
      [norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
         Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
        // using two blendv calls to simulate the 3 cases
        // 1        if  x >= beta
        // -1       if x <= -beta
        // x / beta if |x| < beta
        const auto x = input - target;
        // Sign select: +1 where x > 0, else -1.
        const auto pos_or_neg_1_vec = Vectorized<scalar_t>::blendv(
            neg_1_vec, pos_1_vec, x > zero_vec);
        const auto x_abs = x.abs();
        // Saturate to +/-1 outside the quadratic region, x/beta inside it.
        const auto output = Vectorized<scalar_t>::blendv(
            x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
        return norm_val_vec * output * grad_output;
      }
    );
  });
  }
}
178 
huber_backward_cpu_kernel(TensorIterator & iter,const Scalar & norm,double delta)179 static void huber_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double delta) {
180   ScalarType dtype = iter.dtype(0);
181   AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, dtype, "huber_backward_cpu_out", [&] {
182     auto norm_val = norm.to<scalar_t>();
183     scalar_t delta_val(delta);
184     auto norm_val_vec = Vectorized<scalar_t>(norm_val);
185     auto delta_val_vec = Vectorized<scalar_t>(delta_val);
186     const auto neg_1_vec = Vectorized<scalar_t>(-1);
187     const auto zero_vec = Vectorized<scalar_t>(0);
188     const auto pos_1_vec = Vectorized<scalar_t>(1);
189     cpu_kernel_vec(iter,
190       [=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
191         const auto x = input - target;
192         if (x <= -delta) {
193           return -norm_val * grad_output * delta;
194         } else if (x >= delta) {
195           return norm_val * grad_output * delta;
196         } else {
197           return norm_val * x * grad_output;
198         }
199       },
200       [norm_val_vec, delta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
201          Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
202         // using two blendv calls to simulate the 3 cases
203         // delta     if  x >= delta
204         // -delta    if x <= -delta
205         // x        if |x| < delta
206         const auto x = input - target;
207         const auto pos_or_neg_1_vec = Vectorized<scalar_t>::blendv(
208             neg_1_vec, pos_1_vec, x > zero_vec);
209         const auto x_abs = x.abs();
210         const auto output = Vectorized<scalar_t>::blendv(
211             x, pos_or_neg_1_vec * delta_val_vec, x_abs >= delta_val_vec);
212         return norm_val_vec * output * grad_output;
213       }
214     );
215   });
216 }
217 
mse_backward_cpu_kernel(TensorIterator & iter,const Scalar & value)218 static void mse_backward_cpu_kernel(TensorIterator& iter, const Scalar& value) {
219   ScalarType dtype = iter.dtype(0);
220   AT_DISPATCH_ALL_TYPES(dtype, "mse_backward_cpu_out", [&] {
221     scalar_t scalar_val = value.to<scalar_t>();
222     auto scalar_vec = Vectorized<scalar_t>(scalar_val);
223     cpu_kernel_vec(
224         iter,
225         [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
226           return scalar_val * (self_val - t1_val) * t2_val;
227         },
228         [=](Vectorized<scalar_t> self_vec,
229             Vectorized<scalar_t> t1_vec,
230             Vectorized<scalar_t> t2_vec) {
231           return scalar_vec * (self_vec - t1_vec) *  t2_vec;
232     });
233   });
234 }
235 
236 } // anonymous namespace
237 
238 REGISTER_DISPATCH(addcmul_stub, &addcmul_cpu_kernel);
239 REGISTER_DISPATCH(addcdiv_stub, &addcdiv_cpu_kernel);
240 REGISTER_DISPATCH(smooth_l1_backward_stub, &smooth_l1_backward_cpu_kernel);
241 REGISTER_DISPATCH(huber_backward_stub, &huber_backward_cpu_kernel);
242 REGISTER_DISPATCH(mse_backward_stub, &mse_backward_cpu_kernel);
243 
244 } // namespace at::native
245