1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/TypeSafeSignMath.h>
5*da0073e9SAndroid Build Coastguard Worker #include <cmath>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDA_ARCH__)
8*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMathCompat.h>
9*da0073e9SAndroid Build Coastguard Worker #define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
10*da0073e9SAndroid Build Coastguard Worker #elif defined(__HIPCC__)
11*da0073e9SAndroid Build Coastguard Worker #include <c10/hip/HIPMathCompat.h>
12*da0073e9SAndroid Build Coastguard Worker #define C10_COMPAT_COPYSIGN c10::hip::compat::copysign
13*da0073e9SAndroid Build Coastguard Worker #else
14*da0073e9SAndroid Build Coastguard Worker #include <c10/util/copysign.h>
15*da0073e9SAndroid Build Coastguard Worker #define C10_COMPAT_COPYSIGN c10::copysign
16*da0073e9SAndroid Build Coastguard Worker #endif
17*da0073e9SAndroid Build Coastguard Worker
// The functions in this file should be header-only, as the file is used under
// ABI-compatibility mode.
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker namespace c10 {
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker // NOTE: [Floor Division in Python]
24*da0073e9SAndroid Build Coastguard Worker // Python's __floordiv__ operator is more complicated than just floor(a / b).
25*da0073e9SAndroid Build Coastguard Worker // It aims to maintain the property: a == (a // b) * b + remainder(a, b)
26*da0073e9SAndroid Build Coastguard Worker // which can otherwise fail due to rounding errors in the remainder.
27*da0073e9SAndroid Build Coastguard Worker // So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
28*da0073e9SAndroid Build Coastguard Worker // With some additional fix-ups added to the result.
29*da0073e9SAndroid Build Coastguard Worker //
30*da0073e9SAndroid Build Coastguard Worker // For reference, see CPython's implementation:
31*da0073e9SAndroid Build Coastguard Worker // https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker template <typename scalar_t>
div_floor_floating(scalar_t a,scalar_t b)34*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b)
35*da0073e9SAndroid Build Coastguard Worker __ubsan_ignore_float_divide_by_zero__ {
36*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(b == 0)) {
37*da0073e9SAndroid Build Coastguard Worker // Divide by zero: return standard IEEE result
38*da0073e9SAndroid Build Coastguard Worker return a / b;
39*da0073e9SAndroid Build Coastguard Worker }
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker auto mod = std::fmod(a, b);
42*da0073e9SAndroid Build Coastguard Worker auto div = (a - mod) / b;
43*da0073e9SAndroid Build Coastguard Worker if ((mod != 0) && (b < 0) != (mod < 0)) {
44*da0073e9SAndroid Build Coastguard Worker div -= scalar_t(1);
45*da0073e9SAndroid Build Coastguard Worker }
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker scalar_t floordiv;
48*da0073e9SAndroid Build Coastguard Worker if (div != 0) {
49*da0073e9SAndroid Build Coastguard Worker floordiv = std::floor(div);
50*da0073e9SAndroid Build Coastguard Worker if (div - floordiv > scalar_t(0.5)) {
51*da0073e9SAndroid Build Coastguard Worker floordiv += scalar_t(1.0);
52*da0073e9SAndroid Build Coastguard Worker }
53*da0073e9SAndroid Build Coastguard Worker } else {
54*da0073e9SAndroid Build Coastguard Worker floordiv = C10_COMPAT_COPYSIGN(scalar_t(0), a / b);
55*da0073e9SAndroid Build Coastguard Worker }
56*da0073e9SAndroid Build Coastguard Worker return floordiv;
57*da0073e9SAndroid Build Coastguard Worker }
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker template <typename scalar_t>
div_floor_integer(scalar_t a,scalar_t b)60*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) {
61*da0073e9SAndroid Build Coastguard Worker if (c10::signs_differ(a, b)) {
62*da0073e9SAndroid Build Coastguard Worker // Subtracts one from the results of truncation division if the
63*da0073e9SAndroid Build Coastguard Worker // divisor and dividend have different sign(bit)s and the remainder of
64*da0073e9SAndroid Build Coastguard Worker // the division is nonzero
65*da0073e9SAndroid Build Coastguard Worker const auto quot = a / b;
66*da0073e9SAndroid Build Coastguard Worker const auto rem = a % b;
67*da0073e9SAndroid Build Coastguard Worker return rem ? quot - 1 : quot;
68*da0073e9SAndroid Build Coastguard Worker }
69*da0073e9SAndroid Build Coastguard Worker return a / b;
70*da0073e9SAndroid Build Coastguard Worker }
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker } // namespace c10
73