1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/BinaryOps.h>
3
4 #include <cmath>
5
6 #include <ATen/Dispatch.h>
7 #include <ATen/Dispatch_v2.h>
8 #include <ATen/OpMathType.h>
9 #include <ATen/Parallel.h>
10 #include <ATen/cpu/vec/functional.h>
11 #include <ATen/cpu/vec/vec.h>
12 #include <ATen/native/Math.h>
13 #include <ATen/native/TensorIterator.h>
14 #include <ATen/native/cpu/LogAddExp.h>
15 #include <ATen/native/cpu/Loops.h>
16 #include <c10/macros/Macros.h>
17 #include <c10/util/TypeSafeSignMath.h>
18 #include <c10/util/generic_math.h>
19
20 namespace at::native {
21
22 namespace {
23
24 using namespace vec;
25
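// Helper for the scalar-operand fast paths below: promotes a reduced-precision
// (Half/BFloat16) vector to two float vectors, applies `op` against the
// broadcast scalar `b`, and converts the result back to the reduced type.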
26 template <
27 typename scalar_t,
28 typename Op,
29 typename opmath_t = at::opmath_type<scalar_t>,
30 typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
31 inline Vectorized<scalar_t> binary_op_scalar(
32 const Vectorized<scalar_t>& a,
33 opmath_t b,
34 const Op& op) {
35 Vectorized<opmath_t> vec_b(b);
36 auto [a0, a1] = convert_to_float<scalar_t>(a);
37 return convert_from_float<scalar_t>(op(a0, vec_b), op(a1, vec_b));
38 }
39
40 void add_clamp_kernel(
41 TensorIterator& iter,
42 const Scalar& alpha_scalar,
43 const Scalar& min_val,
44 const Scalar& max_val) {
45 AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_clamp_cpu", [&]() {
46 auto alpha = alpha_scalar.to<scalar_t>();
47 auto alpha_vec = Vectorized<scalar_t>(alpha);
48 auto min_scalar = min_val.to<scalar_t>();
49 auto min_vec = Vectorized<scalar_t>(min_scalar);
50 auto max_scalar = max_val.to<scalar_t>();
51 auto max_vec = Vectorized<scalar_t>(max_scalar);
52 cpu_kernel_vec(
53 iter,
54 [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t {
55 return std::min(
56 max_scalar,
57 std::max(min_scalar, static_cast<scalar_t>(a + alpha * b)));
58 },
59 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
60 __ubsan_ignore_undefined__ {
61 auto add_clamp_res = vec::fmadd(b, alpha_vec, a);
62 add_clamp_res = vec::clamp_min(add_clamp_res, min_vec);
63 add_clamp_res = vec::clamp_max(add_clamp_res, max_vec);
64 return add_clamp_res;
65 });
66 });
67 }
68
69 void atan2_kernel(TensorIteratorBase& iter) {
70 AT_DISPATCH_FLOATING_TYPES_AND2(
71 kBFloat16, kHalf, iter.dtype(), "atan2_cpu", [&]() {
72 cpu_kernel_vec(
73 iter,
74 [=](scalar_t a, scalar_t b) -> scalar_t {
75 return std::atan2(a, b);
76 },
77 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
78 return a.atan2(b);
79 });
80 });
81 }
82
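// Non-mobile builds use the v2 dispatch macros so eq/ne and mul also cover the
// float8 and barebones unsigned dtypes; mobile builds fall back to the classic
// AT_DISPATCH_* macros without those extra types.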
83 #if !defined(C10_MOBILE)
84 #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
85 AT_DISPATCH_V2( \
86 TYPE, \
87 NAME, \
88 AT_WRAP(__VA_ARGS__), \
89 kComplexHalf, \
90 kHalf, \
91 kBool, \
92 kBFloat16, \
93 AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
94 #define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
95 AT_DISPATCH_V2( \
96 TYPE, \
97 NAME, \
98 AT_WRAP(__VA_ARGS__), \
99 kComplexHalf, \
100 kHalf, \
101 kBFloat16, \
102 AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
103 #define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
104 AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
105 kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
106 #else
107 #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
108 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
109 kComplexHalf, kHalf, kBool, kBFloat16, TYPE, NAME, __VA_ARGS__)
110 #define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
111 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
112 kComplexHalf, kHalf, kBFloat16, TYPE, NAME, __VA_ARGS__)
113 #define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
114 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
115 kHalf, kBFloat16, TYPE, NAME, __VA_ARGS__)
116 #endif
117
118 void mul_kernel(TensorIteratorBase& iter) {
119 auto dtype = iter.common_dtype();
120 if (dtype == ScalarType::Bool) {
121 cpu_kernel(iter, [=](bool a, bool b) -> bool { return a && b; });
122 } else if (dtype == kComplexHalf) {
123 cpu_kernel(
124 iter,
125 [=](c10::complex<at::Half> a,
126 c10::complex<at::Half> b) -> c10::complex<at::Half> {
127 using comp_t = c10::complex<float>;
128 return comp_t{a} * comp_t{b};
129 });
130 } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
131 AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "mul_cpu_reduced_float", [&]() {
132 using opmath_t = at::opmath_type<scalar_t>;
133 opmath_t b = iter.original_scalar_value<opmath_t>(2);
134 iter.remove_operand(2);
135 cpu_kernel_vec(
136 iter,
137 [=](scalar_t a) __ubsan_ignore_undefined__ -> scalar_t {
138 return static_cast<opmath_t>(a) * b;
139 },
140 [=](Vectorized<scalar_t> a) __ubsan_ignore_undefined__ {
141 return binary_op_scalar(
142 a,
143 b,
144 [](const Vectorized<opmath_t>& x,
145 const Vectorized<opmath_t>& y) { return x * y; });
146 });
147 });
148 } else {
149 _AT_DISPATCH_MUL_TYPES(dtype, "mul_cpu", [&]() {
150 cpu_kernel_vec(
151 iter,
152 [=](scalar_t a, scalar_t b)
153 __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
154 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
155 __ubsan_ignore_undefined__ { return a * b; });
156 });
157 }
158 }
159
160 void div_true_kernel(TensorIteratorBase& iter) {
161 const auto dtype = iter.common_dtype();
162 if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
163 AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_cpu_reduced_float", [&]() {
164 using opmath_t = at::opmath_type<scalar_t>;
165 opmath_t b = iter.original_scalar_value<opmath_t>(2);
166 iter.remove_operand(2);
167 cpu_kernel_vec(
168 iter,
169 [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
170 return static_cast<opmath_t>(a) / b;
171 },
172 [=](Vectorized<scalar_t> a) {
173 return binary_op_scalar(
174 a,
175 b,
176 [](const Vectorized<opmath_t>& x,
177 const Vectorized<opmath_t>& y) { return x / y; });
178 });
179 });
180 } else {
181 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
182 kBFloat16, kHalf, dtype, "div_cpu", [&]() {
183 cpu_kernel_vec(
184 iter,
185 [](scalar_t a, scalar_t b)
186 __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
187 return a / b;
188 },
189 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
190 return a / b;
191 });
192 });
193 }
194 }
195
196 void div_trunc_kernel(TensorIteratorBase& iter) {
197 const auto dtype = iter.common_dtype();
198 if (isIntegralType(dtype, /*includeBool*/ false)) {
199 // There's no SIMD integer division, so don't try to vectorize it.
200 // TODO: if the divisor is a scalar, rewrite as multiplication by a
201 // constant.
202 AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cpu", [&]() {
203 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
204 TORCH_CHECK(b != 0, "ZeroDivisionError");
205 return a / b;
206 });
207 });
208 } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
209 AT_DISPATCH_REDUCED_FLOATING_TYPES(
210 dtype, "div_trunc_cpu_reduced_float", [&]() {
211 using opmath_t = at::opmath_type<scalar_t>;
212 opmath_t b = iter.original_scalar_value<opmath_t>(2);
213 iter.remove_operand(2);
214 cpu_kernel_vec(
215 iter,
216 [=](scalar_t a)
217 __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
218 return std::trunc(static_cast<opmath_t>(a) / b);
219 },
220 [=](Vectorized<scalar_t> a) {
221 return binary_op_scalar(
222 a,
223 b,
224 [](const Vectorized<opmath_t>& x,
225 const Vectorized<opmath_t>& y) {
226 return (x / y).trunc();
227 });
228 });
229 });
230 } else {
231 AT_DISPATCH_FLOATING_TYPES_AND2(
232 kBFloat16, kHalf, dtype, "div_trunc_cpu", [&]() {
233 cpu_kernel_vec(
234 iter,
235 [](scalar_t a, scalar_t b)
236 __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
237 return std::trunc(a / b);
238 },
239 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
240 return (a / b).trunc();
241 });
242 });
243 }
244 }
245
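// Vectorized counterpart of c10::div_floor_floating: computes a / b and then
// patches the result so the rounding matches Python floor-division semantics,
// including the infinite-quotient, zero-quotient, and b == 0 edge cases
// handled below.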
246 template <typename scalar_t>
247 inline Vectorized<scalar_t> div_floor_floating_vec(
248 const Vectorized<scalar_t>& a,
249 const Vectorized<scalar_t>& b) {
250 using vec_t = Vectorized<scalar_t>;
251 const auto basic_div = a / b;
252 vec_t inf(std::numeric_limits<scalar_t>::infinity());
253 auto mod = a.fmod(b);
254 // Fixup for a case that isn't properly handled by Sleef_fmod
255 auto floor = vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf));
256 auto div = floor / b;
257 const auto zero = vec_t(0);
258 auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
259 const auto one = vec_t(1);
260 div = vec_t::blendv(div, div - one, mask);
261 auto floordiv = div.floor();
262 mask = (div - floordiv) > vec_t(0.5);
263 floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
264 floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero);
265 floordiv = vec_t::blendv(floordiv, basic_div, b == zero);
266 return floordiv;
267 };
268
269 void div_floor_kernel(TensorIteratorBase& iter) {
270 const auto dtype = iter.common_dtype();
271 if (dtype == kByte) {
272 // In the special case of unsigned integer division, floor division is
273 // equivalent to truncation division (since the signs of the divisor and
274 // dividend are always the same)
275 return div_trunc_kernel(iter);
276 } else if (isIntegralType(dtype, /*includeBool*/ false)) {
277 // There's no SIMD integer division, so don't try to vectorize it.
278 AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cpu", [&]() {
279 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
280 TORCH_CHECK(b != 0, "ZeroDivisionError");
281 return c10::div_floor_integer(a, b);
282 });
283 });
284 } else {
285 // See NOTE: [Floor Division in Python]
286 if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
287 AT_DISPATCH_REDUCED_FLOATING_TYPES(
288 dtype, "div_floor_cpu_reduced_float", [&]() {
289 using opmath_t = at::opmath_type<scalar_t>;
290 opmath_t b = iter.original_scalar_value<opmath_t>(2);
291 iter.remove_operand(2);
292 using vec_t = Vectorized<opmath_t>;
293 cpu_kernel_vec(
294 iter,
295 [=](scalar_t a) -> scalar_t {
296 return c10::div_floor_floating(static_cast<opmath_t>(a), b);
297 },
298 [=](Vectorized<scalar_t> a) {
299 return binary_op_scalar(
300 a, b, [](const vec_t& x, const vec_t& y) {
301 return div_floor_floating_vec(x, y);
302 });
303 });
304 });
305 } else {
306 AT_DISPATCH_FLOATING_TYPES_AND2(
307 kBFloat16, kHalf, dtype, "div_floor_cpu", [&]() {
308 using vec_t = Vectorized<scalar_t>;
309 cpu_kernel_vec(
310 iter,
311 [](scalar_t a, scalar_t b) -> scalar_t {
312 return c10::div_floor_floating(a, b);
313 },
314 [](vec_t a, vec_t b) -> vec_t {
315 return div_floor_floating_vec(a, b);
316 });
317 });
318 }
319 }
320 }
321
322 void remainder_kernel(TensorIteratorBase& iter) {
323 if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) {
324 AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cpu", [&]() {
325 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
326 TORCH_CHECK(b != 0, "ZeroDivisionError");
327 scalar_t r = a % b;
328 if ((r != 0) && (c10::is_negative(r) != c10::is_negative(b))) {
329 r += b;
330 }
331 return r;
332 });
333 });
334 } else if (iter.common_dtype() == kBFloat16) {
335 cpu_kernel_vec(
336 iter,
337 [=](BFloat16 a, BFloat16 b)
338 __ubsan_ignore_float_divide_by_zero__ -> BFloat16 {
339 float a0 = static_cast<float>(a);
340 float b0 = static_cast<float>(b);
341 float mod0 = std::fmod(a0, b0);
342 if ((mod0 != 0) && ((b0 < 0) != (mod0 < 0))) {
343 mod0 += b0;
344 }
345 return mod0;
346 },
347 [=](Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
348 auto [a0, a1] = convert_bfloat16_float(a);
349 auto [b0, b1] = convert_bfloat16_float(b);
350 auto mod0 = a0.fmod(b0);
351 auto mod1 = a1.fmod(b1);
352 const auto zero = Vectorized<float>(0);
353 auto mask0 = (mod0 != zero) & ((b0 < zero) ^ (mod0 < zero));
354 auto mask1 = (mod1 != zero) & ((b1 < zero) ^ (mod1 < zero));
355 a0 = Vectorized<float>::blendv(mod0, mod0 + b0, mask0);
356 a1 = Vectorized<float>::blendv(mod1, mod1 + b1, mask1);
357 return convert_float_bfloat16(a0, a1);
358 });
359 } else {
360 AT_DISPATCH_FLOATING_TYPES_AND_HALF(
361 iter.common_dtype(), "remainder_cpu", [&]() {
362 cpu_kernel_vec(
363 iter,
364 [=](scalar_t a, scalar_t b)
365 __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
366 scalar_t mod = std::fmod(a, b);
367 if ((mod != 0) && ((b < 0) != (mod < 0)))
368 mod += b;
369 return mod;
370 },
371 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
372 auto mod = a.fmod(b);
373 const auto zero = Vectorized<scalar_t>(0);
374 auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
375 return Vectorized<scalar_t>::blendv(mod, mod + b, mask);
376 });
377 });
378 }
379 }
380
381 void bitwise_and_kernel(TensorIteratorBase& iter) {
382 if (iter.dtype() == ScalarType::Bool) {
383 cpu_kernel(iter, [](bool a, bool b) { return a && b; });
384 } else {
385 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
386 cpu_kernel_vec(
387 iter,
388 [](scalar_t a, scalar_t b) -> scalar_t { return a & b; },
389 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a & b; });
390 });
391 }
392 }
393
394 void bitwise_or_kernel(TensorIteratorBase& iter) {
395 if (iter.dtype() == ScalarType::Bool) {
396 cpu_kernel(iter, [](bool a, bool b) { return a || b; });
397 } else {
398 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cpu", [&]() {
399 cpu_kernel_vec(
400 iter,
401 [](scalar_t a, scalar_t b) -> scalar_t { return a | b; },
402 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a | b; });
403 });
404 }
405 }
406
407 void bitwise_xor_kernel(TensorIteratorBase& iter) {
408 if (iter.dtype() == ScalarType::Bool) {
409 // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps
410 // this operation for both Boolean and integral types.
411 cpu_kernel(iter, [](bool a, bool b) { return a != b; });
412 } else {
413 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_xor_cpu", [&]() {
414 cpu_kernel_vec(
415 iter,
416 [](scalar_t a, scalar_t b) -> scalar_t { return a ^ b; },
417 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a ^ b; });
418 });
419 }
420 }
421
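// Shifting by a negative amount or by >= the bit width is undefined behaviour
// in C++, so those cases explicitly produce 0; the shift itself is done on the
// unsigned representation to avoid signed-overflow UB.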
422 void lshift_kernel(TensorIteratorBase& iter) {
423 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
424 cpu_kernel_vec(
425 iter,
426 [](scalar_t a, scalar_t b) -> scalar_t {
427 constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
428 if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
429 (b >= max_shift)) {
430 return 0;
431 }
432 return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
433 },
434 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a << b; });
435 });
436 }
437
438 void logical_and_kernel(TensorIterator& iter) {
439 // See Note [special-case bool outputs]
440 if (iter.dtype() == ScalarType::Bool) {
441 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
442 kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
443 cpu_kernel(
444 iter, [](scalar_t a, scalar_t b) -> bool { return a && b; });
445 });
446 } else {
447 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
448 kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
449 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
450 return static_cast<scalar_t>(a && b);
451 });
452 });
453 }
454 }
455
456 void logical_or_kernel(TensorIterator& iter) {
457 // See Note [special-case bool outputs]
458 if (iter.dtype() == ScalarType::Bool) {
459 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
460 kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
461 cpu_kernel(
462 iter, [](scalar_t a, scalar_t b) -> bool { return a || b; });
463 });
464 } else {
465 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
466 kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
467 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
468 return static_cast<scalar_t>(a || b);
469 });
470 });
471 }
472 }
473
474 void logical_xor_kernel(TensorIterator& iter) {
475 // See Note [special-case bool outputs]
476 if (iter.dtype() == ScalarType::Bool) {
477 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
478 kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
479 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool {
480 return bool(a) != bool(b);
481 });
482 });
483 } else {
484 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
485 kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
486 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
487 return static_cast<scalar_t>(bool(a) != bool(b));
488 });
489 });
490 }
491 }
492
493 void rshift_kernel(TensorIteratorBase& iter) {
494 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() {
495 cpu_kernel_vec(
496 iter,
497 [](scalar_t a, scalar_t b) -> scalar_t {
498 // right shift value to retain sign bit for signed and no bits for
499 // unsigned
500 constexpr scalar_t max_shift =
501 sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
502 if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
503 (b >= max_shift)) {
504 return a >> max_shift;
505 }
506 return a >> b;
507 },
508 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a >> b; });
509 });
510 }
511
512 void lt_kernel(TensorIteratorBase& iter) {
513 // See Note [special-case bool outputs]
514 if (iter.dtype() == ScalarType::Bool) {
515 AT_DISPATCH_ALL_TYPES_AND3(
516 kBool, kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
517 cpu_kernel(
518 iter, [](scalar_t a, scalar_t b) -> bool { return a < b; });
519 });
520 } else {
521 AT_DISPATCH_ALL_TYPES_AND2(
522 kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
523 cpu_kernel_vec(
524 iter,
525 [](scalar_t a, scalar_t b) -> scalar_t { return a < b; },
526 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
527 -> Vectorized<scalar_t> { return a.lt(b); });
528 });
529 }
530 }
531
532 void le_kernel(TensorIteratorBase& iter) {
533 // See Note [special-case bool outputs]
534 if (iter.dtype() == ScalarType::Bool) {
535 AT_DISPATCH_ALL_TYPES_AND3(
536 kBool, kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
537 cpu_kernel(
538 iter, [](scalar_t a, scalar_t b) -> bool { return a <= b; });
539 });
540 } else {
541 AT_DISPATCH_ALL_TYPES_AND2(
542 kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
543 cpu_kernel_vec(
544 iter,
545 [](scalar_t a, scalar_t b) -> scalar_t { return a <= b; },
546 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
547 -> Vectorized<scalar_t> { return a.le(b); });
548 });
549 }
550 }
551
552 void gt_kernel(TensorIteratorBase& iter) {
553 // See Note [special-case bool outputs]
554 if (iter.dtype() == ScalarType::Bool) {
555 AT_DISPATCH_ALL_TYPES_AND3(
556 kBool, kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
557 cpu_kernel(
558 iter, [](scalar_t a, scalar_t b) -> bool { return a > b; });
559 });
560 } else {
561 AT_DISPATCH_ALL_TYPES_AND2(
562 kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
563 cpu_kernel_vec(
564 iter,
565 [](scalar_t a, scalar_t b) -> scalar_t { return a > b; },
566 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
567 -> Vectorized<scalar_t> { return a.gt(b); });
568 });
569 }
570 }
571
572 void ge_kernel(TensorIteratorBase& iter) {
573 // See Note [special-case bool outputs]
574 if (iter.dtype() == ScalarType::Bool) {
575 AT_DISPATCH_ALL_TYPES_AND3(
576 kBool, kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
577 cpu_kernel(
578 iter, [](scalar_t a, scalar_t b) -> bool { return a >= b; });
579 });
580 } else {
581 AT_DISPATCH_ALL_TYPES_AND2(
582 kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
583 cpu_kernel_vec(
584 iter,
585 [](scalar_t a, scalar_t b) -> scalar_t { return a >= b; },
586 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
587 -> Vectorized<scalar_t> { return a.ge(b); });
588 });
589 }
590 }
591
592 void eq_kernel(TensorIteratorBase& iter) {
593 // See Note [special-case bool outputs]
594 if (iter.dtype() == ScalarType::Bool) {
595 _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
596 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a == b; });
597 });
598 } else {
599 _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
600 cpu_kernel_vec(
601 iter,
602 [](scalar_t a, scalar_t b) -> scalar_t {
603 return static_cast<scalar_t>(a == b);
604 },
605 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
606 -> Vectorized<scalar_t> { return a.eq(b); });
607 });
608 }
609 }
610
611 void ne_kernel(TensorIteratorBase& iter) {
612 // See Note [special-case bool outputs]
613 if (iter.dtype() == ScalarType::Bool) {
614 _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
615 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a != b; });
616 });
617 } else {
618 _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
619 cpu_kernel_vec(
620 iter,
621 [](scalar_t a, scalar_t b) -> scalar_t {
622 return static_cast<scalar_t>(a != b);
623 },
624 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
625 -> Vectorized<scalar_t> { return a.ne(b); });
626 });
627 }
628 }
629
630 void maximum_kernel(TensorIteratorBase& iter) {
631 if (iter.dtype() == ScalarType::Bool) {
632 cpu_kernel(iter, [](bool a, bool b) -> bool { return a || b; });
633 } else if (isIntegralType(iter.dtype(), /*includeBool=*/false)) {
634 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "maximum_cpu", [&]() {
635 cpu_kernel_vec(
636 iter,
637 [](scalar_t a, scalar_t b) -> scalar_t { return std::max(a, b); },
638 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
639 return at::vec::maximum(a, b);
640 });
641 });
642 } else {
643 AT_DISPATCH_FLOATING_TYPES_AND2(
644 at::ScalarType::Half,
645 at::ScalarType::BFloat16,
646 iter.dtype(),
647 "maximum_cpu",
648 [&]() {
649 cpu_kernel_vec(
650 iter,
651 [](scalar_t a, scalar_t b) -> scalar_t {
652 if (a != a || b != b) {
653 return std::numeric_limits<scalar_t>::quiet_NaN();
654 } else {
655 return std::max(a, b);
656 }
657 },
658 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
659 return at::vec::maximum(a, b);
660 });
661 });
662 }
663 }
664
665 void minimum_kernel(TensorIteratorBase& iter) {
666 if (iter.dtype() == ScalarType::Bool) {
667 cpu_kernel(iter, [](bool a, bool b) -> bool { return a && b; });
668 } else if (isIntegralType(iter.dtype(), /*includeBool=*/false)) {
669 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "minimum_cpu", [&]() {
670 cpu_kernel_vec(
671 iter,
672 [](scalar_t a, scalar_t b) -> scalar_t { return std::min(a, b); },
673 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
674 return at::vec::minimum(a, b);
675 });
676 });
677 } else {
678 AT_DISPATCH_FLOATING_TYPES_AND2(
679 at::ScalarType::Half,
680 at::ScalarType::BFloat16,
681 iter.dtype(),
682 "minimum_cpu",
683 [&]() {
684 cpu_kernel_vec(
685 iter,
686 [](scalar_t a, scalar_t b) -> scalar_t {
687 if (a != a || b != b) {
688 return std::numeric_limits<scalar_t>::quiet_NaN();
689 } else {
690 return std::min(a, b);
691 }
692 },
693 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
694 return at::vec::minimum(a, b);
695 });
696 });
697 }
698 }
699
700 void fmax_kernel(TensorIteratorBase& iter) {
701 if (isFloatingType(iter.common_dtype())) {
702 AT_DISPATCH_FLOATING_TYPES_AND2(
703 at::ScalarType::Half,
704 at::ScalarType::BFloat16,
705 iter.common_dtype(),
706 "fmax_cpu",
707 [&]() {
708 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
709 return std::fmax(a, b);
710 });
711 });
712 } else {
713 maximum_kernel(iter);
714 }
715 }
716
717 void fmin_kernel(TensorIteratorBase& iter) {
718 if (isFloatingType(iter.common_dtype())) {
719 AT_DISPATCH_FLOATING_TYPES_AND2(
720 at::ScalarType::Half,
721 at::ScalarType::BFloat16,
722 iter.common_dtype(),
723 "fmin_cpu",
724 [&]() {
725 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
726 return std::fmin(a, b);
727 });
728 });
729 } else {
730 minimum_kernel(iter);
731 }
732 }
733
734 void smooth_l1_kernel(TensorIteratorBase& iter, double beta) {
735 if (iter.dtype() == kBFloat16) {
736 const float beta_val(static_cast<float>(beta));
737 const Vectorized<float> beta_val_vec(beta_val);
738 const Vectorized<float> point_five_vec(static_cast<float>(0.5));
739 cpu_kernel_vec(
740 iter,
741 [&beta_val](BFloat16 a, BFloat16 b) -> BFloat16 {
742 auto z = std::abs(float(a) - float(b));
743 return z < beta_val ? static_cast<float>(0.5) * z * z / beta_val
744 : z - static_cast<float>(0.5) * beta_val;
745 },
746 [&beta_val_vec, &point_five_vec](
747 Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
748 auto [a0, a1] = convert_bfloat16_float(a);
749 auto [b0, b1] = convert_bfloat16_float(b);
750 auto z = (a0 - b0).abs();
751 a0 = Vectorized<float>::blendv(
752 point_five_vec * z * z / beta_val_vec,
753 z - point_five_vec * beta_val_vec,
754 z >= beta_val_vec);
755 z = (a1 - b1).abs();
756 a1 = Vectorized<float>::blendv(
757 point_five_vec * z * z / beta_val_vec,
758 z - point_five_vec * beta_val_vec,
759 z >= beta_val_vec);
760 return convert_float_bfloat16(a0, a1);
761 });
762 } else {
763 AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_cpu", [&]() {
764 using Vec = Vectorized<scalar_t>;
765 const scalar_t beta_val(beta);
766 const Vec beta_val_vec(beta_val);
767 const Vec point_five_vec(static_cast<scalar_t>(0.5));
768 cpu_kernel_vec(
769 iter,
770 [&beta_val](scalar_t a, scalar_t b) -> scalar_t {
771 auto z = std::abs(a - b);
772 return z < beta_val ? static_cast<scalar_t>(0.5) * z * z / beta_val
773 : z - static_cast<scalar_t>(0.5) * beta_val;
774 },
775 [&beta_val_vec, &point_five_vec](Vec a, Vec b) {
776 auto z = (a - b).abs();
777 return Vec::blendv(
778 point_five_vec * z * z / beta_val_vec,
779 z - point_five_vec * beta_val_vec,
780 z >= beta_val_vec);
781 });
782 });
783 }
784 }
785
786 void huber_kernel(TensorIterator& iter, double delta) {
787 AT_DISPATCH_FLOATING_TYPES_AND2(
788 kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() {
789 using Vec = Vectorized<scalar_t>;
790 const scalar_t delta_val(delta);
791 const Vec delta_val_vec(delta_val);
792 const Vec point_five_vec(static_cast<scalar_t>(0.5));
793 cpu_kernel_vec(
794 iter,
795 [&delta_val](scalar_t a, scalar_t b) -> scalar_t {
796 auto z = std::abs(a - b);
797 return z < delta_val
798 ? static_cast<scalar_t>(0.5) * z * z
799 : delta_val * (z - static_cast<scalar_t>(0.5) * delta_val);
800 },
801 [&delta_val_vec, &point_five_vec](Vec a, Vec b) {
802 auto z = (a - b).abs();
803 return Vec::blendv(
804 point_five_vec * z * z,
805 delta_val_vec * (z - point_five_vec * delta_val_vec),
806 z >= delta_val_vec);
807 });
808 });
809 }
810
811 void sigmoid_backward_kernel(TensorIteratorBase& iter) {
812 if (isComplexType(iter.dtype())) {
813 AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sigmoid_backward_cpu", [&]() {
814 auto one_vec = Vectorized<scalar_t>(scalar_t{1});
815 cpu_kernel_vec(
816 iter,
817 [=](scalar_t a, scalar_t b) -> scalar_t {
818 return a * std::conj((scalar_t(1) - b) * b);
819 },
820 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
821 return a * ((one_vec - b) * b).conj();
822 });
823 });
824 } else if (iter.dtype() == kBFloat16) {
825 auto one_vec = Vectorized<float>((float)(1));
826 cpu_kernel_vec(
827 iter,
828 [=](BFloat16 a, BFloat16 b) -> BFloat16 {
829 float a0 = static_cast<float>(a);
830 float b0 = static_cast<float>(b);
831 return a0 * (float(1) - b0) * b0;
832 },
833 [=](Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
834 auto [a0, a1] = convert_bfloat16_float(a);
835 auto [b0, b1] = convert_bfloat16_float(b);
836 a0 = a0 * (one_vec - b0) * b0;
837 a1 = a1 * (one_vec - b1) * b1;
838 return convert_float_bfloat16(a0, a1);
839 });
840 } else {
841 AT_DISPATCH_FLOATING_TYPES_AND(
842 kHalf, iter.dtype(), "sigmoid_backward_cpu", [&]() {
843 auto one_vec = Vectorized<scalar_t>((scalar_t)(1));
844 cpu_kernel_vec(
845 iter,
846 [=](scalar_t a, scalar_t b) -> scalar_t {
847 return a * (scalar_t(1) - b) * b;
848 },
849 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
850 return a * (one_vec - b) * b;
851 });
852 });
853 }
854 }
855
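// logit_backward: a negative eps indicates the forward logit did not clamp its
// input, so inputs outside [0, 1] get a NaN gradient; otherwise gradients for
// inputs outside [eps, 1 - eps] are zeroed.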
856 void logit_backward_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
857 AT_DISPATCH_FLOATING_TYPES_AND2(
858 kBFloat16, kHalf, iter.dtype(), "logit_backward_cpu", [&]() {
859 const scalar_t eps = eps_scalar.to<scalar_t>();
860 const Vectorized<scalar_t> kZeroVec(scalar_t(0));
861 const Vectorized<scalar_t> kOneVec(scalar_t(1));
862 if (eps < scalar_t(0)) {
863 const Vectorized<scalar_t> kNanVec(
864 std::numeric_limits<scalar_t>::quiet_NaN());
865 cpu_kernel_vec(
866 iter,
867 [](scalar_t dy, scalar_t x) {
868 return (x < scalar_t(0) || x > scalar_t(1))
869 ? std::numeric_limits<scalar_t>::quiet_NaN()
870 : ((x == scalar_t(0) || x == scalar_t(1))
871 ? (dy * std::numeric_limits<scalar_t>::infinity())
872 : (dy / (x * (scalar_t(1) - x))));
873 },
874 [kZeroVec, kOneVec, kNanVec](
875 Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) {
876 return Vectorized<scalar_t>::blendv(
877 kNanVec,
878 dy_vec / (x_vec * (kOneVec - x_vec)),
879 (x_vec >= kZeroVec) & (x_vec <= kOneVec));
880 });
881 } else {
882 const scalar_t lo = eps;
883 const scalar_t hi = scalar_t(1) - eps;
884 const Vectorized<scalar_t> lo_vec(lo);
885 const Vectorized<scalar_t> hi_vec(hi);
886 cpu_kernel_vec(
887 iter,
888 [lo, hi](scalar_t dy, scalar_t x) {
889 return (x < lo || x > hi)
890 ? scalar_t(0)
891 : ((x == scalar_t(0) || x == scalar_t(1))
892 ? dy * std::numeric_limits<scalar_t>::infinity()
893 : dy / (x * (scalar_t(1) - x)));
894 },
895 [kZeroVec, kOneVec, lo_vec, hi_vec](
896 Vectorized<scalar_t> dy_vec, Vectorized<scalar_t> x_vec) {
897 return Vectorized<scalar_t>::blendv(
898 kZeroVec,
899 dy_vec / (x_vec * (kOneVec - x_vec)),
900 (x_vec >= lo_vec) & (x_vec <= hi_vec));
901 });
902 }
903 });
904 }
905
906 void tanh_backward_kernel(TensorIteratorBase& iter) {
907 if (isComplexType(iter.dtype())) {
908 AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
909 auto one_vec = Vectorized<scalar_t>(scalar_t{1});
910 cpu_kernel_vec(
911 iter,
912 [=](scalar_t a, scalar_t b) -> scalar_t {
913 return a * std::conj(scalar_t{1} - b * b);
914 },
915 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
916 return a * (one_vec - b * b).conj();
917 });
918 });
919 } else if (at::isReducedFloatingType(iter.dtype())) {
920 AT_DISPATCH_REDUCED_FLOATING_TYPES(
921 iter.dtype(), "tanh_backward_cpu", [&]() {
922 auto one_vec = Vectorized<float>(float{1});
923 cpu_kernel_vec(
924 iter,
925 [=](scalar_t a, scalar_t b) -> scalar_t {
926 float a0 = float(a);
927 float b0 = float(b);
928 return a0 * (float{1} - b0 * b0);
929 },
930 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
931 auto [a0, a1] = convert_to_float<scalar_t>(a);
932 auto [b0, b1] = convert_to_float<scalar_t>(b);
933 a0 = a0 * (one_vec - b0 * b0);
934 a1 = a1 * (one_vec - b1 * b1);
935 return convert_from_float<scalar_t>(a0, a1);
936 });
937 });
938 } else {
939 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
940 auto one_vec = Vectorized<scalar_t>(scalar_t{1});
941 cpu_kernel_vec(
942 iter,
943 [=](scalar_t a, scalar_t b) -> scalar_t {
944 return a * (scalar_t{1} - b * b);
945 },
946 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
947 return a * (one_vec - b * b);
948 });
949 });
950 }
951 }
952
953 void mse_kernel(TensorIteratorBase& iter) {
954 if (iter.dtype() == ScalarType::Half) {
955 TORCH_WARN_ONCE(
956 "Applying the CPU mse kernel on half-type tensors. "
957 "This may be slower than using float or double-type tensors.");
958 }
959
960 AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "mse_cpu", [&]() {
961 cpu_kernel_vec(
962 iter,
963 [=](scalar_t a, scalar_t b) -> scalar_t {
964 auto diff = a - b;
965 return diff * diff;
966 },
967 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
968 auto diff = a - b;
969 return diff * diff;
970 });
971 });
972 }
973
974 void fmod_kernel(TensorIteratorBase& iter) {
975 if (isIntegralType(iter.common_dtype(), /*includeBool=*/false)) {
976 AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_cpu", [&]() {
977 cpu_kernel(iter, [=](scalar_t x, scalar_t d) -> scalar_t {
978 TORCH_CHECK(d != 0, "ZeroDivisionError");
979 return x % d;
980 });
981 });
982 } else {
983 AT_DISPATCH_FLOATING_TYPES_AND2(
984 kBFloat16, kHalf, iter.common_dtype(), "fmod_cpu", [&]() {
985 cpu_kernel_vec(
986 iter,
987 [](scalar_t x, scalar_t d) -> scalar_t {
988 return std::fmod(x, d);
989 },
990 [](Vectorized<scalar_t> x, Vectorized<scalar_t> d) {
991 return x.fmod(d);
992 });
993 });
994 }
995 }
996
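// logaddexp: when both inputs are the same infinity, the generic
// max(a, b) + log1p(exp(-|a - b|)) formula would yield NaN (from inf - inf),
// so that case returns the input unchanged.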
997 void logaddexp_kernel(TensorIteratorBase& iter) {
998 if (at::isReducedFloatingType(iter.dtype())) {
999 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "logaddexp_cpu", [&]() {
1000 using Vec = Vectorized<scalar_t>;
1001 cpu_kernel_vec(
1002 iter,
1003 [=](scalar_t a, scalar_t b) -> scalar_t {
1004 float a0 = static_cast<float>(a);
1005 float b0 = static_cast<float>(b);
1006 if (std::isinf(a0) && a0 == b0) {
1007 return a0;
1008 } else {
1009 float m0 = std::max(a0, b0);
1010 return m0 + std::log1p(std::exp(-std::abs(a0 - b0)));
1011 }
1012 },
1013 [=](Vec a, Vec b) -> Vec {
1014 auto [a0, a1] = convert_to_float<scalar_t>(a);
1015 auto [b0, b1] = convert_to_float<scalar_t>(b);
1016 Vectorized<float> inf(std::numeric_limits<float>::infinity());
1017 Vectorized<float> m0 = maximum(a0, b0);
1018 Vectorized<float> m1 = maximum(a1, b1);
1019 a0 = Vectorized<float>::blendv(
1020 m0 + (a0 - b0).abs().neg().exp().log1p(),
1021 a0,
1022 (a0 == b0) & (a0.abs() == inf));
1023 a1 = Vectorized<float>::blendv(
1024 m1 + (a1 - b1).abs().neg().exp().log1p(),
1025 a1,
1026 (a1 == b1) & (a1.abs() == inf));
1027 return convert_from_float<scalar_t>(a0, a1);
1028 });
1029 });
1030 } else if (isComplexType(iter.dtype())) {
1031 AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "logaddexp_cpu", [&]() {
1032 cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
1033 return _log_add_exp_helper(a, b);
1034 });
1035 });
1036 } else {
1037 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp_cpu", [&]() {
1038 cpu_kernel_vec(
1039 iter,
1040 [=](scalar_t a, scalar_t b) -> scalar_t {
1041 if (std::isinf(a) && a == b) {
1042 return a;
1043 } else {
1044 scalar_t m = std::max(a, b);
1045 return m + std::log1p(std::exp(-std::abs(a - b)));
1046 }
1047 },
1048 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
1049 Vectorized<scalar_t> inf(std::numeric_limits<scalar_t>::infinity());
1050 Vectorized<scalar_t> m = maximum(a, b);
1051 return Vectorized<scalar_t>::blendv(
1052 m + (a - b).abs().neg().exp().log1p(),
1053 a,
1054 (a == b) & (a.abs() == inf));
1055 });
1056 });
1057 }
1058 }
1059
1060 void logaddexp2_kernel(TensorIteratorBase& iter) {
1061 if (at::isReducedFloatingType(iter.dtype())) {
1062 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "logaddexp2_cpu", [&]() {
1063 using Vec = Vectorized<scalar_t>;
1064 constexpr auto inv_log_2 = static_cast<float>(1.0 / c10::ln_2<double>);
1065 cpu_kernel_vec(
1066 iter,
1067 [=](scalar_t a, scalar_t b) -> scalar_t {
1068 float a0 = static_cast<float>(a);
1069 float b0 = static_cast<float>(b);
1070 if (std::isinf(a0) && a0 == b0) {
1071 return a0;
1072 } else {
1073 float m0 = std::max(a0, b0);
1074 return m0 + std::log1p(std::exp2(-std::abs(a0 - b0))) * inv_log_2;
1075 }
1076 },
1077 [=](Vec a, Vec b) -> Vec {
1078 auto [a0, a1] = convert_to_float<scalar_t>(a);
1079 auto [b0, b1] = convert_to_float<scalar_t>(b);
1080 Vectorized<float> inf(std::numeric_limits<float>::infinity());
1081 Vectorized<float> inv_log_2_vec(inv_log_2);
1082 Vectorized<float> m0 = maximum(a0, b0);
1083 Vectorized<float> m1 = maximum(a1, b1);
1084 a0 = Vectorized<float>::blendv(
1085 m0 + (a0 - b0).abs().neg().exp2().log1p() * inv_log_2_vec,
1086 a0,
1087 (a0 == b0) & (a0.abs() == inf));
1088 a1 = Vectorized<float>::blendv(
1089 m1 + (a1 - b1).abs().neg().exp2().log1p() * inv_log_2_vec,
1090 a1,
1091 (a1 == b1) & (a1.abs() == inf));
1092 return convert_from_float<scalar_t>(a0, a1);
1093 });
1094 });
1095 } else {
1096 AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp2_cpu", [&]() {
1097 constexpr auto inv_log_2 = static_cast<scalar_t>(1.0 / c10::ln_2<double>);
1098 cpu_kernel_vec(
1099 iter,
1100 [=](scalar_t a, scalar_t b) -> scalar_t {
1101 if (std::isinf(a) && a == b) {
1102 return a;
1103 } else {
1104 scalar_t m = std::max(a, b);
1105 return m + std::log1p(std::exp2(-std::abs(a - b))) * inv_log_2;
1106 }
1107 },
1108 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
1109 Vectorized<scalar_t> inf(std::numeric_limits<scalar_t>::infinity());
1110 Vectorized<scalar_t> inv_log_2_vec(inv_log_2);
1111 Vectorized<scalar_t> m = maximum(a, b);
1112 return Vectorized<scalar_t>::blendv(
1113 m + (a - b).abs().neg().exp2().log1p() * inv_log_2_vec,
1114 a,
1115 (a == b) & (a.abs() == inf));
1116 });
1117 });
1118 }
1119 }
1120
1121 void gcd_kernel(TensorIteratorBase& iter) {
1122 AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cpu", [&]() {
1123 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
1124 return calc_gcd(a, b);
1125 });
1126 });
1127 }
1128
1129 void lcm_kernel(TensorIteratorBase& iter) {
1130 AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cpu", [&]() {
1131 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
1132 scalar_t g = calc_gcd(a, b);
1133 return (g == 0) ? 0 : std::abs(a / g * b);
1134 });
1135 });
1136 }
1137
1138 void hypot_kernel(TensorIteratorBase& iter) {
1139 AT_DISPATCH_FLOATING_TYPES_AND2(
1140 kBFloat16, kHalf, iter.dtype(), "hypot_cpu", [&]() {
1141 cpu_kernel_vec(
1142 iter,
1143 [=](scalar_t a, scalar_t b) -> scalar_t {
1144 return std::hypot(a, b);
1145 },
1146 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
1147 return a.hypot(b);
1148 });
1149 });
1150 }
1151
1152 void igamma_kernel(TensorIteratorBase& iter) {
1153 AT_DISPATCH_FLOATING_TYPES_AND2(
1154 kHalf, kBFloat16, iter.dtype(), "igamma_cpu", [&]() {
1155 cpu_kernel_vec(
1156 iter,
1157 [=](scalar_t a, scalar_t b) -> scalar_t {
1158 return calc_igamma(a, b);
1159 },
1160 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
1161 return a.igamma(b);
1162 });
1163 });
1164 }
1165
1166 void igammac_kernel(TensorIteratorBase& iter) {
1167 AT_DISPATCH_FLOATING_TYPES_AND2(
1168 kHalf, kBFloat16, iter.dtype(), "igammac_cpu", [&]() {
1169 cpu_kernel_vec(
1170 iter,
1171 [=](scalar_t a, scalar_t b) -> scalar_t {
1172 return calc_igammac(a, b);
1173 },
1174 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
1175 return a.igammac(b);
1176 });
1177 });
1178 }
1179
1180 void nextafter_kernel(TensorIteratorBase& iter) {
1181 if (at::isReducedFloatingType(iter.common_dtype())) {
1182 AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() {
1183 cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
1184 return std::nextafter(a, b);
1185 });
1186 });
1187 } else {
1188 AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cpu", [&]() {
1189 cpu_kernel_vec(
1190 iter,
1191 [=](scalar_t a, scalar_t b) -> scalar_t {
1192 return std::nextafter(a, b);
1193 },
1194 [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
1195 return a.nextafter(b);
1196 });
1197 });
1198 }
1199 }
1200
1201 void heaviside_kernel(TensorIteratorBase& iter) {
1202 AT_DISPATCH_ALL_TYPES_AND3(
1203 kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() {
1204 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
1205 return a == 0 ? b : static_cast<scalar_t>(a > 0);
1206 });
1207 });
1208 }
1209
1210 void copysign_kernel(TensorIteratorBase& iter) {
1211 AT_DISPATCH_FLOATING_TYPES_AND2(
1212 kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
1213 cpu_kernel_vec(
1214 iter,
1215 [](scalar_t a, scalar_t b) -> scalar_t {
1216 return c10::copysign(a, b);
1217 },
1218 [](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
1219 -> Vectorized<scalar_t> { return a.copysign(b); });
1220 });
1221 }
1222
1223 void xlogy_kernel(TensorIteratorBase& iter) {
1224 AT_DISPATCH_FLOATING_TYPES_AND2(
1225 kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() {
1226 cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
1227 if (at::_isnan(y)) {
1228 return NAN;
1229 }
1230 if (x == 0) {
1231 return 0;
1232 }
1233 return x * std::log(y);
1234 });
1235 });
1236 }
1237
1238 void xlog1py_kernel(TensorIteratorBase& iter) {
1239 AT_DISPATCH_FLOATING_TYPES_AND2(
1240 kBFloat16, kHalf, iter.common_dtype(), "xlog1py_cpu", [&]() {
1241 cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
1242 if (at::_isnan(y)) {
1243 return NAN;
1244 }
1245 if (x == 0) {
1246 return 0;
1247 }
1248 return x * std::log1p(y);
1249 });
1250 });
1251 }
1252
1253 void zeta_kernel(TensorIteratorBase& iter) {
1254 AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cpu", [&]() {
1255 cpu_kernel(
1256 iter, [](scalar_t x, scalar_t q) -> scalar_t { return zeta(x, q); });
1257 });
1258 }
1259
1260 void chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator) {
1261 AT_DISPATCH_FLOATING_TYPES(
1262 iterator.common_dtype(), "chebyshev_polynomial_t_cpu", [&]() {
1263 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1264 return chebyshev_polynomial_t_forward(x, n);
1265 });
1266 });
1267 } // chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator)
1268
1269 void chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator) {
1270 AT_DISPATCH_FLOATING_TYPES(
1271 iterator.common_dtype(), "chebyshev_polynomial_u_cpu", [&]() {
1272 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1273 return chebyshev_polynomial_u_forward(x, n);
1274 });
1275 });
1276 } // chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator)
1277
1278 void chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator) {
1279 AT_DISPATCH_FLOATING_TYPES(
1280 iterator.common_dtype(), "chebyshev_polynomial_v_cpu", [&]() {
1281 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1282 return chebyshev_polynomial_v_forward(x, n);
1283 });
1284 });
1285 } // chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator)
1286
1287 void chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator) {
1288 AT_DISPATCH_FLOATING_TYPES(
1289 iterator.common_dtype(), "chebyshev_polynomial_w_cpu", [&]() {
1290 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1291 return chebyshev_polynomial_w_forward(x, n);
1292 });
1293 });
1294 } // chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator)
1295
1296 void hermite_polynomial_h_kernel(TensorIteratorBase& iterator) {
1297 AT_DISPATCH_FLOATING_TYPES(
1298 iterator.common_dtype(), "hermite_polynomial_h_cpu", [&]() {
1299 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1300 return hermite_polynomial_h_forward(x, n);
1301 });
1302 });
1303 } // hermite_polynomial_h_kernel(TensorIteratorBase& iterator)
1304
1305 void hermite_polynomial_he_kernel(TensorIteratorBase& iterator) {
1306 AT_DISPATCH_FLOATING_TYPES(
1307 iterator.common_dtype(), "hermite_polynomial_he_cpu", [&]() {
1308 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1309 return hermite_polynomial_he_forward(x, n);
1310 });
1311 });
1312 } // hermite_polynomial_he_kernel(TensorIteratorBase& iterator)
1313
1314 void laguerre_polynomial_l_kernel(TensorIteratorBase& iterator) {
1315 AT_DISPATCH_FLOATING_TYPES(
1316 iterator.common_dtype(), "laguerre_polynomial_l_cpu", [&]() {
1317 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1318 return laguerre_polynomial_l_forward(x, n);
1319 });
1320 });
1321 } // laguerre_polynomial_l_kernel(TensorIteratorBase& iterator)
1322
1323 void legendre_polynomial_p_kernel(TensorIteratorBase& iterator) {
1324 AT_DISPATCH_FLOATING_TYPES(
1325 iterator.common_dtype(), "legendre_polynomial_p_cpu", [&]() {
1326 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1327 return legendre_polynomial_p_forward(x, n);
1328 });
1329 });
1330 } // legendre_polynomial_p_kernel(TensorIteratorBase& iterator)
1331
1332 void shifted_chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator) {
1333 AT_DISPATCH_FLOATING_TYPES(
1334 iterator.common_dtype(), "shifted_chebyshev_polynomial_t_cpu", [&]() {
1335 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1336 return shifted_chebyshev_polynomial_t_forward(x, n);
1337 });
1338 });
1339 } // shifted_chebyshev_polynomial_t_kernel(TensorIteratorBase& iterator)
1340
1341 void shifted_chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator) {
1342 AT_DISPATCH_FLOATING_TYPES(
1343 iterator.common_dtype(), "shifted_chebyshev_polynomial_u_cpu", [&]() {
1344 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1345 return shifted_chebyshev_polynomial_u_forward(x, n);
1346 });
1347 });
1348 } // shifted_chebyshev_polynomial_u_kernel(TensorIteratorBase& iterator)
1349
1350 void shifted_chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator) {
1351 AT_DISPATCH_FLOATING_TYPES(
1352 iterator.common_dtype(), "shifted_chebyshev_polynomial_v_cpu", [&]() {
1353 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1354 return shifted_chebyshev_polynomial_v_forward(x, n);
1355 });
1356 });
1357 } // shifted_chebyshev_polynomial_v_kernel(TensorIteratorBase& iterator)
1358
1359 void shifted_chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator) {
1360 AT_DISPATCH_FLOATING_TYPES(
1361 iterator.common_dtype(), "shifted_chebyshev_polynomial_w_cpu", [&]() {
1362 cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
1363 return shifted_chebyshev_polynomial_w_forward(x, n);
1364 });
1365 });
1366 } // shifted_chebyshev_polynomial_w_kernel(TensorIteratorBase& iterator)
1367
1368 } // namespace
1369
1370 REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);
1371 REGISTER_DISPATCH(mul_stub, &mul_kernel);
1372 REGISTER_DISPATCH(div_true_stub, &div_true_kernel);
1373 REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel);
1374 REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel);
1375 REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
1376 REGISTER_DISPATCH(bitwise_or_stub, &bitwise_or_kernel);
1377 REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
1378 REGISTER_DISPATCH(lshift_stub, &lshift_kernel);
1379 REGISTER_DISPATCH(rshift_stub, &rshift_kernel);
1380 REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
1381 REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel);
1382 REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel);
1383 REGISTER_DISPATCH(lt_stub, &lt_kernel);
1384 REGISTER_DISPATCH(le_stub, &le_kernel);
1385 REGISTER_DISPATCH(gt_stub, &gt_kernel);
1386 REGISTER_DISPATCH(ge_stub, &ge_kernel);
1387 REGISTER_DISPATCH(eq_stub, &eq_kernel);
1388 REGISTER_DISPATCH(ne_stub, &ne_kernel);
1389 REGISTER_DISPATCH(maximum_stub, &maximum_kernel);
1390 REGISTER_DISPATCH(minimum_stub, &minimum_kernel);
1391 REGISTER_DISPATCH(fmax_stub, &fmax_kernel);
1392 REGISTER_DISPATCH(fmin_stub, &fmin_kernel);
1393 REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
1394 REGISTER_DISPATCH(remainder_stub, &remainder_kernel);
1395 REGISTER_DISPATCH(fmod_stub, &fmod_kernel);
1396 REGISTER_DISPATCH(gcd_stub, &gcd_kernel);
1397 REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
1398 REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
1399 REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel);
1400 REGISTER_DISPATCH(zeta_stub, &zeta_kernel);
1401 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
1402 REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
1403 REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_kernel);
1404 REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_kernel);
1405 REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_kernel);
1406 REGISTER_DISPATCH(laguerre_polynomial_l_stub, &laguerre_polynomial_l_kernel);
1407 REGISTER_DISPATCH(legendre_polynomial_p_stub, &legendre_polynomial_p_kernel);
1408 REGISTER_DISPATCH(
1409 shifted_chebyshev_polynomial_t_stub,
1410 &shifted_chebyshev_polynomial_t_kernel);
1411 REGISTER_DISPATCH(
1412 shifted_chebyshev_polynomial_u_stub,
1413 &shifted_chebyshev_polynomial_u_kernel);
1414 REGISTER_DISPATCH(
1415 shifted_chebyshev_polynomial_v_stub,
1416 &shifted_chebyshev_polynomial_v_kernel);
1417 REGISTER_DISPATCH(
1418 shifted_chebyshev_polynomial_w_stub,
1419 &shifted_chebyshev_polynomial_w_kernel);
1420 // Might enable AVX512 dispatch after enabling explicit vectorization for them.
1421 REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel);
1422 REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel);
1423 REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_kernel);
1424
1425 ALSO_REGISTER_AVX512_DISPATCH(atan2_stub, &atan2_kernel);
1426 ALSO_REGISTER_AVX512_DISPATCH(smooth_l1_stub, &smooth_l1_kernel);
1427 ALSO_REGISTER_AVX512_DISPATCH(huber_stub, &huber_kernel);
1428 ALSO_REGISTER_AVX512_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel);
1429 ALSO_REGISTER_AVX512_DISPATCH(logit_backward_stub, &logit_backward_kernel);
1430 ALSO_REGISTER_AVX512_DISPATCH(tanh_backward_stub, &tanh_backward_kernel);
1431 ALSO_REGISTER_AVX512_DISPATCH(mse_stub, &mse_kernel);
1432 ALSO_REGISTER_AVX512_DISPATCH(logaddexp_stub, &logaddexp_kernel);
1433 ALSO_REGISTER_AVX512_DISPATCH(logaddexp2_stub, &logaddexp2_kernel);
1434 ALSO_REGISTER_AVX512_DISPATCH(hypot_stub, &hypot_kernel);
1435 ALSO_REGISTER_AVX512_DISPATCH(igamma_stub, &igamma_kernel);
1436 ALSO_REGISTER_AVX512_DISPATCH(igammac_stub, &igammac_kernel);
1437
1438 } // namespace at::native
1439