// aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
// (AOSP mirror of pytorch, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/BinaryOps.h>

#include <cmath>

#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/LogAddExp.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/macros/Macros.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/generic_math.h>

namespace at::native {

namespace {

using namespace vec;

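// Helper for tensor-scalar ops on reduced floating-point types (BFloat16,
// Half): the input vector is widened to two float vectors, the op runs in
// opmath_t (float) precision, and the result is narrowed back.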
template <
    typename scalar_t,
    typename Op,
    typename opmath_t = at::opmath_type<scalar_t>,
    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline Vectorized<scalar_t> binary_op_scalar(
    const Vectorized<scalar_t>& a,
    opmath_t b,
    const Op& op) {
  Vectorized<opmath_t> vec_b(b);
  auto [a0, a1] = convert_to_float<scalar_t>(a);
  return convert_from_float<scalar_t>(op(a0, vec_b), op(a1, vec_b));
}

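// Computes clamp(a + alpha * b, min_val, max_val); the vectorized path uses
// fmadd followed by clamp_min/clamp_max.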
void add_clamp_kernel(
    TensorIterator& iter,
    const Scalar& alpha_scalar,
    const Scalar& min_val,
    const Scalar& max_val) {
  AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_clamp_cpu", [&]() {
    auto alpha = alpha_scalar.to<scalar_t>();
    auto alpha_vec = Vectorized<scalar_t>(alpha);
    auto min_scalar = min_val.to<scalar_t>();
    auto min_vec = Vectorized<scalar_t>(min_scalar);
    auto max_scalar = max_val.to<scalar_t>();
    auto max_vec = Vectorized<scalar_t>(max_scalar);
    cpu_kernel_vec(
        iter,
        [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t {
          return std::min(
              max_scalar,
              std::max(min_scalar, static_cast<scalar_t>(a + alpha * b)));
        },
        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
            __ubsan_ignore_undefined__ {
              auto add_clamp_res = vec::fmadd(b, alpha_vec, a);
              add_clamp_res = vec::clamp_min(add_clamp_res, min_vec);
              add_clamp_res = vec::clamp_max(add_clamp_res, max_vec);
              return add_clamp_res;
            });
  });
}

void atan2_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "atan2_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [=](scalar_t a, scalar_t b) -> scalar_t {
              return std::atan2(a, b);
            },
            [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
              return a.atan2(b);
            });
      });
}

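// On non-mobile builds the V2 dispatch macros below additionally cover the
// float8 and barebones unsigned types; mobile builds fall back to the classic
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND* macros.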
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
  AT_DISPATCH_V2(                                        \
      TYPE,                                              \
      NAME,                                              \
      AT_WRAP(__VA_ARGS__),                              \
      kComplexHalf,                                      \
      kHalf,                                             \
      kBool,                                             \
      kBFloat16,                                         \
      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
  AT_DISPATCH_V2(                                       \
      TYPE,                                             \
      NAME,                                             \
      AT_WRAP(__VA_ARGS__),                             \
      kComplexHalf,                                     \
      kHalf,                                            \
      kBFloat16,                                        \
      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...)    \
  AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
      kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                \
      kComplexHalf, kHalf, kBool, kBFloat16, TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(               \
      kComplexHalf, kHalf, kBFloat16, TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(       \
      kHalf, kBFloat16, TYPE, NAME, __VA_ARGS__)
#endif

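// Multiplication: bool uses logical AND, complex half is computed in
// complex<float>, and a reduced floating-point tensor multiplied by a CPU
// scalar takes a single-operand fast path that folds the scalar into the
// lambda via binary_op_scalar.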
void mul_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (dtype == ScalarType::Bool) {
    cpu_kernel(iter, [=](bool a, bool b) -> bool { return a && b; });
  } else if (dtype == kComplexHalf) {
    cpu_kernel(
        iter,
        [=](c10::complex<at::Half> a,
            c10::complex<at::Half> b) -> c10::complex<at::Half> {
          using comp_t = c10::complex<float>;
          return comp_t{a} * comp_t{b};
        });
  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "mul_cpu_reduced_float", [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      opmath_t b = iter.original_scalar_value<opmath_t>(2);
      iter.remove_operand(2);
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) __ubsan_ignore_undefined__ -> scalar_t {
            return static_cast<opmath_t>(a) * b;
          },
          [=](Vectorized<scalar_t> a) __ubsan_ignore_undefined__ {
            return binary_op_scalar(
                a,
                b,
                [](const Vectorized<opmath_t>& x,
                   const Vectorized<opmath_t>& y) { return x * y; });
          });
    });
  } else {
    _AT_DISPATCH_MUL_TYPES(dtype, "mul_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b)
              __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
              __ubsan_ignore_undefined__ { return a * b; });
    });
  }
}

void div_true_kernel(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_cpu_reduced_float", [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      opmath_t b = iter.original_scalar_value<opmath_t>(2);
      iter.remove_operand(2);
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
            return static_cast<opmath_t>(a) / b;
          },
          [=](Vectorized<scalar_t> a) {
            return binary_op_scalar(
                a,
                b,
                [](const Vectorized<opmath_t>& x,
                   const Vectorized<opmath_t>& y) { return x / y; });
          });
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        kBFloat16, kHalf, dtype, "div_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b)
                  __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
                    return a / b;
                  },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
                return a / b;
              });
        });
  }
}

void div_trunc_kernel(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  if (isIntegralType(dtype, /*includeBool*/ false)) {
    // There's no SIMD integer division, so don't try to vectorize it.
    // TODO: if the divisor is a scalar, rewrite as multiplication by a
    // constant.
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cpu", [&]() {
      cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
        TORCH_CHECK(b != 0, "ZeroDivisionError");
        return a / b;
      });
    });
  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(
        dtype, "div_trunc_cpu_reduced_float", [&]() {
          using opmath_t = at::opmath_type<scalar_t>;
          opmath_t b = iter.original_scalar_value<opmath_t>(2);
          iter.remove_operand(2);
          cpu_kernel_vec(
              iter,
              [=](scalar_t a)
                  __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
                    return std::trunc(static_cast<opmath_t>(a) / b);
                  },
              [=](Vectorized<scalar_t> a) {
                return binary_op_scalar(
                    a,
                    b,
                    [](const Vectorized<opmath_t>& x,
                       const Vectorized<opmath_t>& y) {
                      return (x / y).trunc();
                    });
              });
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kBFloat16, kHalf, dtype, "div_trunc_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b)
                  __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
                    return std::trunc(a / b);
                  },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
                return (a / b).trunc();
              });
        });
  }
}

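// Vectorized counterpart of c10::div_floor_floating: floor division with
// Python semantics. It computes (a - fmod(a, b)) / b, subtracts one where the
// remainder is nonzero and differs in sign from b, then patches up the
// division-by-zero and infinite-quotient cases.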
template <typename scalar_t>
inline Vectorized<scalar_t> div_floor_floating_vec(
    const Vectorized<scalar_t>& a,
    const Vectorized<scalar_t>& b) {
  using vec_t = Vectorized<scalar_t>;
  const auto basic_div = a / b;
  vec_t inf(std::numeric_limits<scalar_t>::infinity());
  auto mod = a.fmod(b);
  // Fixup for a case that isn't properly handled by Sleef_fmod
  auto floor = vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf));
  auto div = floor / b;
  const auto zero = vec_t(0);
  auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
  const auto one = vec_t(1);
  div = vec_t::blendv(div, div - one, mask);
  auto floordiv = div.floor();
  mask = (div - floordiv) > vec_t(0.5);
  floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
  floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero);
  floordiv = vec_t::blendv(floordiv, basic_div, b == zero);
  return floordiv;
}

void div_floor_kernel(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  if (dtype == kByte) {
    // In the special case of unsigned integer division, floor division is
    // equivalent to truncation division (since the signs of the divisor and
    // dividend are always the same)
    return div_trunc_kernel(iter);
  } else if (isIntegralType(dtype, /*includeBool*/ false)) {
    // There's no SIMD integer division, so don't try to vectorize it.
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cpu", [&]() {
      cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
        TORCH_CHECK(b != 0, "ZeroDivisionError");
        return c10::div_floor_integer(a, b);
      });
    });
  } else {
    // See NOTE: [Floor Division in Python]
    if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
          dtype, "div_floor_cpu_reduced_float", [&]() {
            using opmath_t = at::opmath_type<scalar_t>;
            opmath_t b = iter.original_scalar_value<opmath_t>(2);
            iter.remove_operand(2);
            using vec_t = Vectorized<opmath_t>;
            cpu_kernel_vec(
                iter,
                [=](scalar_t a) -> scalar_t {
                  return c10::div_floor_floating(static_cast<opmath_t>(a), b);
                },
                [=](Vectorized<scalar_t> a) {
                  return binary_op_scalar(
                      a, b, [](const vec_t& x, const vec_t& y) {
                        return div_floor_floating_vec(x, y);
                      });
                });
          });
    } else {
      AT_DISPATCH_FLOATING_TYPES_AND2(
          kBFloat16, kHalf, dtype, "div_floor_cpu", [&]() {
            using vec_t = Vectorized<scalar_t>;
            cpu_kernel_vec(
                iter,
                [](scalar_t a, scalar_t b) -> scalar_t {
                  return c10::div_floor_floating(a, b);
                },
                [](vec_t a, vec_t b) -> vec_t {
                  return div_floor_floating_vec(a, b);
                });
          });
    }
  }
}

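// Python-style remainder: the result takes the sign of the divisor. For
// integers this is a % b adjusted by +b when the signs differ; for floating
// types it is fmod(a, b) with the same adjustment.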
void remainder_kernel(TensorIteratorBase& iter) {
  if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cpu", [&]() {
      cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
        TORCH_CHECK(b != 0, "ZeroDivisionError");
        scalar_t r = a % b;
        if ((r != 0) && (c10::is_negative(r) != c10::is_negative(b))) {
          r += b;
        }
        return r;
      });
    });
  } else if (iter.common_dtype() == kBFloat16) {
    cpu_kernel_vec(
        iter,
        [=](BFloat16 a, BFloat16 b)
            __ubsan_ignore_float_divide_by_zero__ -> BFloat16 {
              float a0 = static_cast<float>(a);
              float b0 = static_cast<float>(b);
              float mod0 = std::fmod(a0, b0);
              if ((mod0 != 0) && ((b0 < 0) != (mod0 < 0))) {
                mod0 += b0;
              }
              return mod0;
            },
        [=](Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
          auto [a0, a1] = convert_bfloat16_float(a);
          auto [b0, b1] = convert_bfloat16_float(b);
          auto mod0 = a0.fmod(b0);
          auto mod1 = a1.fmod(b1);
          const auto zero = Vectorized<float>(0);
          auto mask0 = (mod0 != zero) & ((b0 < zero) ^ (mod0 < zero));
          auto mask1 = (mod1 != zero) & ((b1 < zero) ^ (mod1 < zero));
          a0 = Vectorized<float>::blendv(mod0, mod0 + b0, mask0);
          a1 = Vectorized<float>::blendv(mod1, mod1 + b1, mask1);
          return convert_float_bfloat16(a0, a1);
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        iter.common_dtype(), "remainder_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [=](scalar_t a, scalar_t b)
                  __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
                    scalar_t mod = std::fmod(a, b);
                    if ((mod != 0) && ((b < 0) != (mod < 0)))
                      mod += b;
                    return mod;
                  },
              [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
                auto mod = a.fmod(b);
                const auto zero = Vectorized<scalar_t>(0);
                auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
                return Vectorized<scalar_t>::blendv(mod, mod + b, mask);
              });
        });
  }
}

void bitwise_and_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(iter, [](bool a, bool b) { return a && b; });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t { return a & b; },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a & b; });
    });
  }
}

void bitwise_or_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(iter, [](bool a, bool b) { return a || b; });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t { return a | b; },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a | b; });
    });
  }
}

void bitwise_xor_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor
    // wraps this operation for both Boolean and integral types.
    cpu_kernel(iter, [](bool a, bool b) { return a != b; });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_xor_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t { return a ^ b; },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a ^ b; });
    });
  }
}

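// Left shift with well-defined out-of-range behavior: a negative shift count
// or a shift of at least the bit width yields 0, and the shift itself is done
// on the unsigned representation to avoid undefined behavior.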
void lshift_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [](scalar_t a, scalar_t b) -> scalar_t {
          constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
          if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
              (b >= max_shift)) {
            return 0;
          }
          return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
        },
        [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a << b; });
  });
}

void logical_and_kernel(TensorIterator& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
          cpu_kernel(
              iter, [](scalar_t a, scalar_t b) -> bool { return a && b; });
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
        kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
          cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
            return static_cast<scalar_t>(a && b);
          });
        });
  }
}

void logical_or_kernel(TensorIterator& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
          cpu_kernel(
              iter, [](scalar_t a, scalar_t b) -> bool { return a || b; });
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
          cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
            return static_cast<scalar_t>(a || b);
          });
        });
  }
}

void logical_xor_kernel(TensorIterator& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
          cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool {
            return bool(a) != bool(b);
          });
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
        kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
          cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
            return static_cast<scalar_t>(bool(a) != bool(b));
          });
        });
  }
}

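// Right shift with well-defined out-of-range behavior: a negative or too-large
// shift count is clamped to max_shift, which fills the result with the sign
// bit for signed types and with zeros for unsigned types.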
void rshift_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [](scalar_t a, scalar_t b) -> scalar_t {
          // right shift value to retain sign bit for signed and no bits for
          // unsigned
          constexpr scalar_t max_shift =
              sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
          if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
              (b >= max_shift)) {
            return a >> max_shift;
          }
          return a >> b;
        },
        [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a >> b; });
  });
}

void lt_kernel(TensorIteratorBase& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
          cpu_kernel(
              iter, [](scalar_t a, scalar_t b) -> bool { return a < b; });
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(
        kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b) -> scalar_t { return a < b; },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
                  -> Vectorized<scalar_t> { return a.lt(b); });
        });
  }
}

void le_kernel(TensorIteratorBase& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
          cpu_kernel(
              iter, [](scalar_t a, scalar_t b) -> bool { return a <= b; });
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(
        kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b) -> scalar_t { return a <= b; },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
                  -> Vectorized<scalar_t> { return a.le(b); });
        });
  }
}

void gt_kernel(TensorIteratorBase& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
          cpu_kernel(
              iter, [](scalar_t a, scalar_t b) -> bool { return a > b; });
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(
        kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b) -> scalar_t { return a > b; },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
                  -> Vectorized<scalar_t> { return a.gt(b); });
        });
  }
}

void ge_kernel(TensorIteratorBase& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND3(
        kBool, kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
          cpu_kernel(
              iter, [](scalar_t a, scalar_t b) -> bool { return a >= b; });
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(
        kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b) -> scalar_t { return a >= b; },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
                  -> Vectorized<scalar_t> { return a.ge(b); });
        });
  }
}

void eq_kernel(TensorIteratorBase& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
      cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a == b; });
    });
  } else {
    _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t {
            return static_cast<scalar_t>(a == b);
          },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
              -> Vectorized<scalar_t> { return a.eq(b); });
    });
  }
}

void ne_kernel(TensorIteratorBase& iter) {
  // See Note [special-case bool outputs]
  if (iter.dtype() == ScalarType::Bool) {
    _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
      cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a != b; });
    });
  } else {
    _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t {
            return static_cast<scalar_t>(a != b);
          },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
              -> Vectorized<scalar_t> { return a.ne(b); });
    });
  }
}

void maximum_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(iter, [](bool a, bool b) -> bool { return a || b; });
  } else if (isIntegralType(iter.dtype(), /*includeBool=*/false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "maximum_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t { return std::max(a, b); },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            return at::vec::maximum(a, b);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        iter.dtype(),
        "maximum_cpu",
        [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b) -> scalar_t {
                if (a != a || b != b) {
                  return std::numeric_limits<scalar_t>::quiet_NaN();
                } else {
                  return std::max(a, b);
                }
              },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
                return at::vec::maximum(a, b);
              });
        });
  }
}

void minimum_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    cpu_kernel(iter, [](bool a, bool b) -> bool { return a && b; });
  } else if (isIntegralType(iter.dtype(), /*includeBool=*/false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "minimum_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [](scalar_t a, scalar_t b) -> scalar_t { return std::min(a, b); },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            return at::vec::minimum(a, b);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        iter.dtype(),
        "minimum_cpu",
        [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t a, scalar_t b) -> scalar_t {
                if (a != a || b != b) {
                  return std::numeric_limits<scalar_t>::quiet_NaN();
                } else {
                  return std::min(a, b);
                }
              },
              [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
                return at::vec::minimum(a, b);
              });
        });
  }
}

void fmax_kernel(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        iter.common_dtype(),
        "fmax_cpu",
        [&]() {
          cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
            return std::fmax(a, b);
          });
        });
  } else {
    maximum_kernel(iter);
  }
}

void fmin_kernel(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        iter.common_dtype(),
        "fmin_cpu",
        [&]() {
          cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
            return std::fmin(a, b);
          });
        });
  } else {
    minimum_kernel(iter);
  }
}

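// Smooth L1 loss: with z = |a - b|, the result is 0.5 * z^2 / beta when
// z < beta and z - 0.5 * beta otherwise. The BFloat16 path accumulates in
// float.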
void smooth_l1_kernel(TensorIteratorBase& iter, double beta) {
  if (iter.dtype() == kBFloat16) {
    const float beta_val(static_cast<float>(beta));
    const Vectorized<float> beta_val_vec(beta_val);
    const Vectorized<float> point_five_vec(static_cast<float>(0.5));
    cpu_kernel_vec(
        iter,
        [&beta_val](BFloat16 a, BFloat16 b) -> BFloat16 {
          auto z = std::abs(float(a) - float(b));
          return z < beta_val ? static_cast<float>(0.5) * z * z / beta_val
                              : z - static_cast<float>(0.5) * beta_val;
        },
        [&beta_val_vec, &point_five_vec](
            Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
          auto [a0, a1] = convert_bfloat16_float(a);
          auto [b0, b1] = convert_bfloat16_float(b);
          auto z = (a0 - b0).abs();
          a0 = Vectorized<float>::blendv(
              point_five_vec * z * z / beta_val_vec,
              z - point_five_vec * beta_val_vec,
              z >= beta_val_vec);
          z = (a1 - b1).abs();
          a1 = Vectorized<float>::blendv(
              point_five_vec * z * z / beta_val_vec,
              z - point_five_vec * beta_val_vec,
              z >= beta_val_vec);
          return convert_float_bfloat16(a0, a1);
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      const scalar_t beta_val(beta);
      const Vec beta_val_vec(beta_val);
      const Vec point_five_vec(static_cast<scalar_t>(0.5));
      cpu_kernel_vec(
          iter,
          [&beta_val](scalar_t a, scalar_t b) -> scalar_t {
            auto z = std::abs(a - b);
            return z < beta_val ? static_cast<scalar_t>(0.5) * z * z / beta_val
                                : z - static_cast<scalar_t>(0.5) * beta_val;
          },
          [&beta_val_vec, &point_five_vec](Vec a, Vec b) {
            auto z = (a - b).abs();
            return Vec::blendv(
                point_five_vec * z * z / beta_val_vec,
                z - point_five_vec * beta_val_vec,
                z >= beta_val_vec);
          });
    });
  }
}

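// Huber loss: with z = |a - b|, the result is 0.5 * z^2 when z < delta and
// delta * (z - 0.5 * delta) otherwise.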
void huber_kernel(TensorIterator& iter, double delta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() {
        using Vec = Vectorized<scalar_t>;
        const scalar_t delta_val(delta);
        const Vec delta_val_vec(delta_val);
        const Vec point_five_vec(static_cast<scalar_t>(0.5));
        cpu_kernel_vec(
            iter,
            [&delta_val](scalar_t a, scalar_t b) -> scalar_t {
              auto z = std::abs(a - b);
              return z < delta_val
                  ? static_cast<scalar_t>(0.5) * z * z
                  : delta_val * (z - static_cast<scalar_t>(0.5) * delta_val);
            },
            [&delta_val_vec, &point_five_vec](Vec a, Vec b) {
              auto z = (a - b).abs();
              return Vec::blendv(
                  point_five_vec * z * z,
                  delta_val_vec * (z - point_five_vec * delta_val_vec),
                  z >= delta_val_vec);
            });
      });
}

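// Sigmoid backward: grad_input = grad_output * (1 - output) * output, using
// the conjugate of the local derivative for complex types and float
// accumulation for BFloat16.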
void sigmoid_backward_kernel(TensorIteratorBase& iter) {
  if (isComplexType(iter.dtype())) {
    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sigmoid_backward_cpu", [&]() {
      auto one_vec = Vectorized<scalar_t>(scalar_t{1});
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            return a * std::conj((scalar_t(1) - b) * b);
          },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            return a * ((one_vec - b) * b).conj();
          });
    });
  } else if (iter.dtype() == kBFloat16) {
    auto one_vec = Vectorized<float>((float)(1));
    cpu_kernel_vec(
        iter,
        [=](BFloat16 a, BFloat16 b) -> BFloat16 {
          float a0 = static_cast<float>(a);
          float b0 = static_cast<float>(b);
          return a0 * (float(1) - b0) * b0;
        },
        [=](Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
          auto [a0, a1] = convert_bfloat16_float(a);
          auto [b0, b1] = convert_bfloat16_float(b);
          a0 = a0 * (one_vec - b0) * b0;
          a1 = a1 * (one_vec - b1) * b1;
          return convert_float_bfloat16(a0, a1);
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND(
        kHalf, iter.dtype(), "sigmoid_backward_cpu", [&]() {
          auto one_vec = Vectorized<scalar_t>((scalar_t)(1));
          cpu_kernel_vec(
              iter,
              [=](scalar_t a, scalar_t b) -> scalar_t {
                return a * (scalar_t(1) - b) * b;
              },
              [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
                return a * (one_vec - b) * b;
              });
        });
  }
}

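// Logit backward: the local derivative of logit(x) is 1 / (x * (1 - x)), so
// grad_input = grad_output / (x * (1 - x)). With eps < 0, inputs outside
// [0, 1] produce NaN; otherwise inputs outside [eps, 1 - eps] produce 0.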
void logit_backward_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "logit_backward_cpu", [&]() {
        const scalar_t eps = eps_scalar.to<scalar_t>();
        const Vectorized<scalar_t> kZeroVec(scalar_t(0));
        const Vectorized<scalar_t> kOneVec(scalar_t(1));
        if (eps < scalar_t(0)) {
          const Vectorized<scalar_t> kNanVec(
              std::numeric_limits<scalar_t>::quiet_NaN());
          cpu_kernel_vec(
              iter,
              [](scalar_t dy, scalar_t x) {
                return (x < scalar_t(0) || x > scalar_t(1))
                    ? std::numeric_limits<scalar_t>::quiet_NaN()
                    : ((x == scalar_t(0) || x == scalar_t(1))
                           ? (dy * std::numeric_limits<scalar_t>::infinity())
                           : (dy / (x * (scalar_t(1) - x))));
              },
              [kZeroVec, kOneVec, kNanVec](
                  Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) {
                return Vectorized<scalar_t>::blendv(
                    kNanVec,
                    dy_vec / (x_vec * (kOneVec - x_vec)),
                    (x_vec >= kZeroVec) & (x_vec <= kOneVec));
              });
        } else {
          const scalar_t lo = eps;
          const scalar_t hi = scalar_t(1) - eps;
          const Vectorized<scalar_t> lo_vec(lo);
          const Vectorized<scalar_t> hi_vec(hi);
          cpu_kernel_vec(
              iter,
              [lo, hi](scalar_t dy, scalar_t x) {
                return (x < lo || x > hi)
                    ? scalar_t(0)
                    : ((x == scalar_t(0) || x == scalar_t(1))
                           ? dy * std::numeric_limits<scalar_t>::infinity()
                           : dy / (x * (scalar_t(1) - x)));
              },
              [kZeroVec, kOneVec, lo_vec, hi_vec](
                  Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) {
                return Vectorized<scalar_t>::blendv(
                    kZeroVec,
                    dy_vec / (x_vec * (kOneVec - x_vec)),
                    (x_vec >= lo_vec) & (x_vec <= hi_vec));
              });
        }
      });
}

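// Tanh backward: grad_input = grad_output * (1 - output^2), conjugated for
// complex types and accumulated in float for reduced floating-point types.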
void tanh_backward_kernel(TensorIteratorBase& iter) {
  if (isComplexType(iter.dtype())) {
    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
      auto one_vec = Vectorized<scalar_t>(scalar_t{1});
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            return a * std::conj(scalar_t{1} - b * b);
          },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            return a * (one_vec - b * b).conj();
          });
    });
  } else if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(
        iter.dtype(), "tanh_backward_cpu", [&]() {
          auto one_vec = Vectorized<float>(float{1});
          cpu_kernel_vec(
              iter,
              [=](scalar_t a, scalar_t b) -> scalar_t {
                float a0 = float(a);
                float b0 = float(b);
                return a0 * (float{1} - b0 * b0);
              },
              [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
                auto [a0, a1] = convert_to_float<scalar_t>(a);
                auto [b0, b1] = convert_to_float<scalar_t>(b);
                a0 = a0 * (one_vec - b0 * b0);
                a1 = a1 * (one_vec - b1 * b1);
                return convert_from_float<scalar_t>(a0, a1);
              });
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
      auto one_vec = Vectorized<scalar_t>(scalar_t{1});
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            return a * (scalar_t{1} - b * b);
          },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            return a * (one_vec - b * b);
          });
    });
  }
}

void mse_kernel(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Half) {
    TORCH_WARN_ONCE(
        "Applying the CPU mse kernel on half-type tensors. "
        "This may be slower than using float or double-type tensors.");
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "mse_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a, scalar_t b) -> scalar_t {
          auto diff = a - b;
          return diff * diff;
        },
        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
          auto diff = a - b;
          return diff * diff;
        });
  });
}

void fmod_kernel(TensorIteratorBase& iter) {
  if (isIntegralType(iter.common_dtype(), /*includeBool=*/false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_cpu", [&]() {
      cpu_kernel(iter, [=](scalar_t x, scalar_t d) -> scalar_t {
        TORCH_CHECK(d != 0, "ZeroDivisionError");
        return x % d;
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kBFloat16, kHalf, iter.common_dtype(), "fmod_cpu", [&]() {
          cpu_kernel_vec(
              iter,
              [](scalar_t x, scalar_t d) -> scalar_t {
                return std::fmod(x, d);
              },
              [](Vectorized<scalar_t> x, Vectorized<scalar_t> d) {
                return x.fmod(d);
              });
        });
  }
}

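// Numerically stable log(exp(a) + exp(b)):
//   max(a, b) + log1p(exp(-|a - b|)),
// with the equal-infinities case returned directly so that (-inf, -inf) and
// (inf, inf) do not produce NaN.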
void logaddexp_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "logaddexp_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            float a0 = static_cast<float>(a);
            float b0 = static_cast<float>(b);
            if (std::isinf(a0) && a0 == b0) {
              return a0;
            } else {
              float m0 = std::max(a0, b0);
              return m0 + std::log1p(std::exp(-std::abs(a0 - b0)));
            }
          },
          [=](Vec a, Vec b) -> Vec {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto [b0, b1] = convert_to_float<scalar_t>(b);
            Vectorized<float> inf(std::numeric_limits<float>::infinity());
            Vectorized<float> m0 = maximum(a0, b0);
            Vectorized<float> m1 = maximum(a1, b1);
            a0 = Vectorized<float>::blendv(
                m0 + (a0 - b0).abs().neg().exp().log1p(),
                a0,
                (a0 == b0) & (a0.abs() == inf));
            a1 = Vectorized<float>::blendv(
                m1 + (a1 - b1).abs().neg().exp().log1p(),
                a1,
                (a1 == b1) & (a1.abs() == inf));
            return convert_from_float<scalar_t>(a0, a1);
          });
    });
  } else if (isComplexType(iter.dtype())) {
    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "logaddexp_cpu", [&]() {
      cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
        return _log_add_exp_helper(a, b);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            if (std::isinf(a) && a == b) {
              return a;
            } else {
              scalar_t m = std::max(a, b);
              return m + std::log1p(std::exp(-std::abs(a - b)));
            }
          },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            Vectorized<scalar_t> inf(std::numeric_limits<scalar_t>::infinity());
            Vectorized<scalar_t> m = maximum(a, b);
            return Vectorized<scalar_t>::blendv(
                m + (a - b).abs().neg().exp().log1p(),
                a,
                (a == b) & (a.abs() == inf));
          });
    });
  }
}

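// Base-2 variant: log2(2^a + 2^b) = max(a, b) + log1p(exp2(-|a - b|)) / ln(2),
// again special-casing equal infinities.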
void logaddexp2_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "logaddexp2_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;
      constexpr auto inv_log_2 = static_cast<float>(1.0 / c10::ln_2<double>);
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            float a0 = static_cast<float>(a);
            float b0 = static_cast<float>(b);
            if (std::isinf(a0) && a0 == b0) {
              return a0;
            } else {
              float m0 = std::max(a0, b0);
              return m0 + std::log1p(std::exp2(-std::abs(a0 - b0))) * inv_log_2;
            }
          },
          [=](Vec a, Vec b) -> Vec {
            auto [a0, a1] = convert_to_float<scalar_t>(a);
            auto [b0, b1] = convert_to_float<scalar_t>(b);
            Vectorized<float> inf(std::numeric_limits<float>::infinity());
            Vectorized<float> inv_log_2_vec(inv_log_2);
            Vectorized<float> m0 = maximum(a0, b0);
            Vectorized<float> m1 = maximum(a1, b1);
            a0 = Vectorized<float>::blendv(
                m0 + (a0 - b0).abs().neg().exp2().log1p() * inv_log_2_vec,
                a0,
                (a0 == b0) & (a0.abs() == inf));
            a1 = Vectorized<float>::blendv(
                m1 + (a1 - b1).abs().neg().exp2().log1p() * inv_log_2_vec,
                a1,
                (a1 == b1) & (a1.abs() == inf));
            return convert_from_float<scalar_t>(a0, a1);
          });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp2_cpu", [&]() {
      constexpr auto inv_log_2 = static_cast<scalar_t>(1.0 / c10::ln_2<double>);
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            if (std::isinf(a) && a == b) {
              return a;
            } else {
              scalar_t m = std::max(a, b);
              return m + std::log1p(std::exp2(-std::abs(a - b))) * inv_log_2;
            }
          },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            Vectorized<scalar_t> inf(std::numeric_limits<scalar_t>::infinity());
            Vectorized<scalar_t> inv_log_2_vec(inv_log_2);
            Vectorized<scalar_t> m = maximum(a, b);
            return Vectorized<scalar_t>::blendv(
                m + (a - b).abs().neg().exp2().log1p() * inv_log_2_vec,
                a,
                (a == b) & (a.abs() == inf));
          });
    });
  }
}

void gcd_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
      return calc_gcd(a, b);
    });
  });
}

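// lcm(a, b) = |a / gcd(a, b) * b|, with lcm(0, 0) defined as 0.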
void lcm_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
      scalar_t g = calc_gcd(a, b);
      return (g == 0) ? 0 : std::abs(a / g * b);
    });
  });
}

void hypot_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "hypot_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [=](scalar_t a, scalar_t b) -> scalar_t {
              return std::hypot(a, b);
            },
            [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
              return a.hypot(b);
            });
      });
}

void igamma_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.dtype(), "igamma_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [=](scalar_t a, scalar_t b) -> scalar_t {
              return calc_igamma(a, b);
            },
            [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
              return a.igamma(b);
            });
      });
}

void igammac_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.dtype(), "igammac_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [=](scalar_t a, scalar_t b) -> scalar_t {
              return calc_igammac(a, b);
            },
            [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
              return a.igammac(b);
            });
      });
}

void nextafter_kernel(TensorIteratorBase& iter) {
  if (at::isReducedFloatingType(iter.common_dtype())) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() {
      cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
        return std::nextafter(a, b);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a, scalar_t b) -> scalar_t {
            return std::nextafter(a, b);
          },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
            return a.nextafter(b);
          });
    });
  }
}

void heaviside_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(
      kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() {
        cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
          return a == 0 ? b : static_cast<scalar_t>(a > 0);
        });
      });
}

void copysign_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
        cpu_kernel_vec(
            iter,
            [](scalar_t a, scalar_t b) -> scalar_t {
              return c10::copysign(a, b);
            },
            [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
                -> Vectorized<scalar_t> { return a.copysign(b); });
      });
}

void xlogy_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() {
        cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
          if (at::_isnan(y)) {
            return NAN;
          }
          if (x == 0) {
            return 0;
          }
          return x * std::log(y);
        });
      });
}

void xlog1py_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.common_dtype(), "xlog1py_cpu", [&]() {
        cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
          if (at::_isnan(y)) {
            return NAN;
          }
          if (x == 0) {
            return 0;
          }
          return x * std::log1p(y);
        });
      });
}

void zeta_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cpu", [&]() {
    cpu_kernel(
        iter, [](scalar_t x, scalar_t q) -> scalar_t { return zeta(x, q); });
  });
}

void chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "chebyshev_polynomial_t_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return chebyshev_polynomial_t_forward(x, n);
        });
      });
} // chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator)

void chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "chebyshev_polynomial_u_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return chebyshev_polynomial_u_forward(x, n);
        });
      });
} // chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator)

void chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "chebyshev_polynomial_v_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return chebyshev_polynomial_v_forward(x, n);
        });
      });
} // chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator)

void chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "chebyshev_polynomial_w_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return chebyshev_polynomial_w_forward(x, n);
        });
      });
} // chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator)

void hermite_polynomial_h_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "hermite_polynomial_h_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return hermite_polynomial_h_forward(x, n);
        });
      });
} // hermite_polynomial_h_kernel(TensorIteratorBase& iterator)

void hermite_polynomial_he_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "hermite_polynomial_he_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return hermite_polynomial_he_forward(x, n);
        });
      });
} // hermite_polynomial_he_kernel(TensorIteratorBase& iterator)

void laguerre_polynomial_l_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "laguerre_polynomial_l_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return laguerre_polynomial_l_forward(x, n);
        });
      });
} // laguerre_polynomial_l_kernel(TensorIteratorBase& iterator)

void legendre_polynomial_p_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "legendre_polynomial_p_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return legendre_polynomial_p_forward(x, n);
        });
      });
} // legendre_polynomial_p_kernel(TensorIteratorBase& iterator)

void shifted_chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "shifted_chebyshev_polynomial_t_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return shifted_chebyshev_polynomial_t_forward(x, n);
        });
      });
} // shifted_chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator)

void shifted_chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "shifted_chebyshev_polynomial_u_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return shifted_chebyshev_polynomial_u_forward(x, n);
        });
      });
} // shifted_chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator)

void shifted_chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "shifted_chebyshev_polynomial_v_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return shifted_chebyshev_polynomial_v_forward(x, n);
        });
      });
} // shifted_chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator)

void shifted_chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator) {
  AT_DISPATCH_FLOATING_TYPES(
      iterator.common_dtype(), "shifted_chebyshev_polynomial_w_cpu", [&]() {
        cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
          return shifted_chebyshev_polynomial_w_forward(x, n);
        });
      });
} // shifted_chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator)

} // namespace

REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);
REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_true_stub, &div_true_kernel);
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel);
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
REGISTER_DISPATCH(bitwise_or_stub, &bitwise_or_kernel);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
REGISTER_DISPATCH(lshift_stub, &lshift_kernel);
REGISTER_DISPATCH(rshift_stub, &rshift_kernel);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel);
REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel);
REGISTER_DISPATCH(lt_stub, &lt_kernel);
REGISTER_DISPATCH(le_stub, &le_kernel);
REGISTER_DISPATCH(gt_stub, &gt_kernel);
REGISTER_DISPATCH(ge_stub, &ge_kernel);
REGISTER_DISPATCH(eq_stub, &eq_kernel);
REGISTER_DISPATCH(ne_stub, &ne_kernel);
REGISTER_DISPATCH(maximum_stub, &maximum_kernel);
REGISTER_DISPATCH(minimum_stub, &minimum_kernel);
REGISTER_DISPATCH(fmax_stub, &fmax_kernel);
REGISTER_DISPATCH(fmin_stub, &fmin_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
REGISTER_DISPATCH(remainder_stub, &remainder_kernel);
REGISTER_DISPATCH(fmod_stub, &fmod_kernel);
REGISTER_DISPATCH(gcd_stub, &gcd_kernel);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel);
REGISTER_DISPATCH(zeta_stub, &zeta_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_kernel);
REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_kernel);
REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_kernel);
REGISTER_DISPATCH(laguerre_polynomial_l_stub, &laguerre_polynomial_l_kernel);
REGISTER_DISPATCH(legendre_polynomial_p_stub, &legendre_polynomial_p_kernel);
REGISTER_DISPATCH(
    shifted_chebyshev_polynomial_t_stub,
    &shifted_chebyshev_polynomial_t_kernel);
REGISTER_DISPATCH(
    shifted_chebyshev_polynomial_u_stub,
    &shifted_chebyshev_polynomial_u_kernel);
REGISTER_DISPATCH(
    shifted_chebyshev_polynomial_v_stub,
    &shifted_chebyshev_polynomial_v_kernel);
REGISTER_DISPATCH(
    shifted_chebyshev_polynomial_w_stub,
    &shifted_chebyshev_polynomial_w_kernel);
// Might enable AVX512 dispatch after enabling explicit vectorization for them.
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel);
REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel);
REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_kernel);

ALSO_REGISTER_AVX512_DISPATCH(atan2_stub, &atan2_kernel);
ALSO_REGISTER_AVX512_DISPATCH(smooth_l1_stub, &smooth_l1_kernel);
ALSO_REGISTER_AVX512_DISPATCH(huber_stub, &huber_kernel);
ALSO_REGISTER_AVX512_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(logit_backward_stub, &logit_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(tanh_backward_stub, &tanh_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mse_stub, &mse_kernel);
ALSO_REGISTER_AVX512_DISPATCH(logaddexp_stub, &logaddexp_kernel);
ALSO_REGISTER_AVX512_DISPATCH(logaddexp2_stub, &logaddexp2_kernel);
ALSO_REGISTER_AVX512_DISPATCH(hypot_stub, &hypot_kernel);
ALSO_REGISTER_AVX512_DISPATCH(igamma_stub, &igamma_kernel);
ALSO_REGISTER_AVX512_DISPATCH(igammac_stub, &igammac_kernel);

} // namespace at::native