xref: /aosp_15_r20/external/pytorch/caffe2/utils/fixed_divisor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef CAFFE2_UTILS_FIXED_DIVISOR_H_
2 #define CAFFE2_UTILS_FIXED_DIVISOR_H_
3 
4 #include <cstdint>
5 #include <cstdio>
6 #include <cstdlib>
7 
8 // See Note [hip-clang differences to hcc]
9 
// FIXED_DIVISOR_DECL is the specifier prefix placed in front of the
// FixedDivisor member functions that are declared with it further down.
// It expands to `inline __host__ __device__` when any of __CUDA_ARCH__,
// __HIP_ARCH__ or __HIP__ is defined, or when __clang__ and __CUDA__ are
// both defined; in every other configuration it expands to plain `inline`.
// NOTE(review): presumably the first branch corresponds to CUDA/HIP
// compilation, where the execution-space specifiers are available -- confirm
// against the build configuration.
10 #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) || defined(__HIP__) || \
11     (defined(__clang__) && defined(__CUDA__))
12 #define FIXED_DIVISOR_DECL inline __host__ __device__
13 #else
14 #define FIXED_DIVISOR_DECL inline
15 #endif
16 
17 namespace caffe2 {
18 
// Utility class for quickly calculating quotients and remainders for
// a known integer divisor.
//
// The primary template is empty; the std::int32_t specialization that
// follows carries the implementation.
template <typename T>
class FixedDivisor {};

// Specialization for signed 32-bit integers.
//
// Works for any positive divisor, 1 to INT_MAX. One 64-bit
// multiplication and one 64-bit shift is used to calculate the
// result.
//
// Preconditions (not checked at runtime):
//   * The divisor must be positive. A divisor of 0 makes the magic-number
//     setup divide by zero.
//   * Numerators passed to Div/Mod/DivMod should be non-negative: Div
//     converts `n` to uint64 (which sign-extends negative values) before the
//     multiply, so the shifted product is only a valid quotient for n >= 0.
//
// When USE_ROCM is defined the magic-number members are compiled out and
// Div is a plain `n / d_`.
template <>
class FixedDivisor<std::int32_t> {
 public:
  /// Constructs a divisor of 1. The magic/shift members default to the values
  /// used for d == 1, so Div/Mod/DivMod are usable on a default-constructed
  /// object.
  FixedDivisor() = default;

  /// Constructs a divisor for `d` and (unless USE_ROCM is defined)
  /// precomputes the multiplier and shift used by Div.
  explicit FixedDivisor(const std::int32_t d) : d_(d) {
#if !defined(USE_ROCM)
    CalcSignedMagic();
#endif // USE_ROCM
  }

  /// Returns the divisor this object was built for.
  FIXED_DIVISOR_DECL std::int32_t d() const {
    return d_;
  }

#if !defined(USE_ROCM)
  /// Returns the precomputed multiplier used by Div.
  FIXED_DIVISOR_DECL std::uint64_t magic() const {
    return magic_;
  }

  /// Returns the precomputed right-shift amount used by Div.
  FIXED_DIVISOR_DECL int shift() const {
    return shift_;
  }
#endif // USE_ROCM

  /// Calculates `q = n / d`.
  FIXED_DIVISOR_DECL std::int32_t Div(const std::int32_t n) const {
#if defined(USE_ROCM)
    return n / d_;
#else // USE_ROCM
    // In lieu of a mulhi instruction being available, perform the
    // work in uint64: multiply by the magic value, then shift the
    // 64-bit product down.
    return static_cast<std::int32_t>(
        (magic_ * static_cast<std::uint64_t>(n)) >> shift_);
#endif // USE_ROCM
  }

  /// Calculates `r = n % d`.
  FIXED_DIVISOR_DECL std::int32_t Mod(const std::int32_t n) const {
    return n - d_ * Div(n);
  }

  /// Calculates `q = n / d` and `r = n % d` together.
  /// Both `q` and `r` are written unconditionally and must be non-null.
  FIXED_DIVISOR_DECL void
  DivMod(const std::int32_t n, std::int32_t* q, std::int32_t* r) const {
    *q = Div(n);
    *r = n - d_ * *q;
  }

 private:
#if !defined(USE_ROCM)
  // Calculates magic multiplicative value and shift amount for calculating `q =
  // n / d` for signed 32-bit integers.
  // Implementation taken from Hacker's Delight section 10.
  void CalcSignedMagic() {
    if (d_ == 1) {
      // Dividing by one: multiply by 2**32 and shift right by 32.
      magic_ = UINT64_C(0x1) << 32;
      shift_ = 32;
      return;
    }

    const std::uint32_t two31 = UINT32_C(0x80000000);
    const std::uint32_t ad = static_cast<std::uint32_t>(std::abs(d_));
    // t is 2**31 plus the sign bit of d_.
    const std::uint32_t t = two31 + (static_cast<std::uint32_t>(d_) >> 31);
    const std::uint32_t anc = t - 1 - t % ad; // Absolute value of nc.
    std::uint32_t p = 31; // Init. p.
    std::uint32_t q1 = two31 / anc; // Init. q1 = 2**p/|nc|.
    std::uint32_t r1 = two31 - q1 * anc; // Init. r1 = rem(2**p, |nc|).
    std::uint32_t q2 = two31 / ad; // Init. q2 = 2**p/|d|.
    std::uint32_t r2 = two31 - q2 * ad; // Init. r2 = rem(2**p, |d|).
    std::uint32_t delta = 0;
    do {
      ++p;
      q1 <<= 1; // Update q1 = 2**p/|nc|.
      r1 <<= 1; // Update r1 = rem(2**p, |nc|).
      if (r1 >= anc) { // (Must be an unsigned
        ++q1; // comparison here).
        r1 -= anc;
      }
      q2 <<= 1; // Update q2 = 2**p/|d|.
      r2 <<= 1; // Update r2 = rem(2**p, |d|).
      if (r2 >= ad) { // (Must be an unsigned
        ++q2; // comparison here).
        r2 -= ad;
      }
      delta = ad - r2;
    } while (q1 < delta || (q1 == delta && r1 == 0));

    // The multiplier is kept as its 32-bit pattern, zero-extended to 64 bits.
    // Negation for a negative divisor is done in unsigned arithmetic, which
    // yields the same two's-complement bit pattern without going through a
    // signed intermediate.
    std::uint32_t magic = q2 + 1;
    if (d_ < 0) {
      magic = UINT32_C(0) - magic;
    }
    shift_ = static_cast<int>(p);
    magic_ = static_cast<std::uint64_t>(magic);
  }
#endif // USE_ROCM

  std::int32_t d_ = 1;

#if !defined(USE_ROCM)
  // Defaults match d_ == 1 so that a default-constructed divisor never reads
  // indeterminate values in Div.
  std::uint64_t magic_ = UINT64_C(0x1) << 32;
  int shift_ = 32;
#endif // USE_ROCM
};
129 
130 } // namespace caffe2
131 
132 #endif // CAFFE2_UTILS_FIXED_DIVISOR_H_
133