1 #pragma once
2
3 #include <c10/macros/Macros.h>
4 #include <c10/util/TypeSafeSignMath.h>
5 #include <cmath>
6
7 #if defined(__CUDA_ARCH__)
8 #include <c10/cuda/CUDAMathCompat.h>
9 #define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
10 #elif defined(__HIPCC__)
11 #include <c10/hip/HIPMathCompat.h>
12 #define C10_COMPAT_COPYSIGN c10::hip::compat::copysign
13 #else
14 #include <c10/util/copysign.h>
15 #define C10_COMPAT_COPYSIGN c10::copysign
16 #endif
17
18 // The functions in this file should be header-only as it is used under
19 // ABI-compatibility mode.
20
21 namespace c10 {
22
23 // NOTE: [Floor Division in Python]
// Python's __floordiv__ operator is more complicated than just floor(a / b).
// It aims to maintain the property: a == (a // b) * b + remainder(a, b),
// which floor(a / b) can violate because a / b is rounded before floor() is
// applied. For example, 1 / 0.1 rounds to 10.0, yet the exact remainder
// fmod(1, 0.1) is ~0.09999999999999995, so the consistent quotient is 9.
// So instead it is calculated as: a // b = (a - remainder(a, b)) / b,
// with some additional fix-ups applied to the result.
29 //
30 // For reference, see CPython's implementation:
31 // https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
32
33 template <typename scalar_t>
div_floor_floating(scalar_t a,scalar_t b)34 inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b)
35 __ubsan_ignore_float_divide_by_zero__ {
36 if (C10_UNLIKELY(b == 0)) {
37 // Divide by zero: return standard IEEE result
38 return a / b;
39 }
40
41 auto mod = std::fmod(a, b);
42 auto div = (a - mod) / b;
43 if ((mod != 0) && (b < 0) != (mod < 0)) {
44 div -= scalar_t(1);
45 }
46
47 scalar_t floordiv;
48 if (div != 0) {
49 floordiv = std::floor(div);
50 if (div - floordiv > scalar_t(0.5)) {
51 floordiv += scalar_t(1.0);
52 }
53 } else {
54 floordiv = C10_COMPAT_COPYSIGN(scalar_t(0), a / b);
55 }
56 return floordiv;
57 }
58
59 template <typename scalar_t>
div_floor_integer(scalar_t a,scalar_t b)60 inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) {
61 if (c10::signs_differ(a, b)) {
62 // Subtracts one from the results of truncation division if the
63 // divisor and dividend have different sign(bit)s and the remainder of
64 // the division is nonzero
65 const auto quot = a / b;
66 const auto rem = a % b;
67 return rem ? quot - 1 : quot;
68 }
69 return a / b;
70 }
71
72 } // namespace c10
73