// xref: /aosp_15_r20/external/pytorch/c10/util/int128.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// This file is based on the uint128 implementation of protobuf at
// https://github.com/protocolbuffers/protobuf/blob/1e88936fce10cf773cb72b44c6a7f48b38c7578b/src/google/protobuf/stubs/int128.h
//
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.  All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 #pragma once
34 
35 #include <c10/macros/Export.h>
36 #include <cstdint>
37 #include <iosfwd>
38 
39 namespace c10 {
40 
// Qualifier applied to the uint128 constructors. It becomes `constexpr` only
// when GOOGLE_PROTOBUF_HAS_CONSTEXPR is defined; otherwise it expands to
// nothing.
// TODO(xiaofeng): Define GOOGLE_PROTOBUF_HAS_CONSTEXPR when constexpr is
// available.
#if defined(GOOGLE_PROTOBUF_HAS_CONSTEXPR)
#define UINT128_CONSTEXPR constexpr
#else
#define UINT128_CONSTEXPR
#endif

// Declarations that must be visible before the uint128 class definition.
struct uint128_pod;
class uint128;
inline uint128& operator<<=(uint128& self, int amount);
53 
54 // An unsigned 128-bit integer type. Thread-compatible.
55 class C10_API uint128 {
56  public:
57   UINT128_CONSTEXPR uint128(); // Sets to 0, but don't trust on this behavior.
58   UINT128_CONSTEXPR uint128(uint64_t top, uint64_t bottom);
59 #ifndef SWIG
60   UINT128_CONSTEXPR uint128(int bottom);
61   UINT128_CONSTEXPR uint128(uint32_t bottom); // Top 96 bits = 0
62 #endif
63   UINT128_CONSTEXPR uint128(uint64_t bottom); // hi_ = 0
64   UINT128_CONSTEXPR uint128(const uint128_pod& val);
65 
66   // Trivial copy constructor, assignment operator and destructor.
67 
68   void Initialize(uint64_t top, uint64_t bottom);
69 
70   // Arithmetic operators.
71   uint128& operator+=(const uint128& b);
72   uint128& operator-=(const uint128& b);
73   uint128& operator*=(const uint128& b);
74   // Long division/modulo for uint128.
75   uint128& operator/=(const uint128& b);
76   uint128& operator%=(const uint128& b);
77   uint128 operator++(int);
78   uint128 operator--(int);
79   // Make msvc happy with using operator<<= from DivModImpl
80   // which is a static function, and linker complained about missing
81   // static version of this overload
82   friend uint128& operator<<=(uint128&, int);
83   uint128& operator>>=(int);
84   uint128& operator&=(const uint128& b);
85   uint128& operator|=(const uint128& b);
86   uint128& operator^=(const uint128& b);
87   uint128& operator++();
88   uint128& operator--();
89 
90   friend uint64_t Uint128Low64(const uint128& v);
91   friend uint64_t Uint128High64(const uint128& v);
92 
93   // We add "std::" to avoid including all of port.h.
94   C10_API friend std::ostream& operator<<(std::ostream& o, const uint128& b);
95 
96  private:
97   static void DivModImpl(
98       uint128 dividend,
99       uint128 divisor,
100       uint128* quotient_ret,
101       uint128* remainder_ret);
102 
103   // Little-endian memory order optimizations can benefit from
104   // having lo_ first, hi_ last.
105   // See util/endian/endian.h and Load128/Store128 for storing a uint128.
106   uint64_t lo_;
107   uint64_t hi_;
108 
109   // Not implemented, just declared for catching automatic type conversions.
110   uint128(uint8_t);
111   uint128(uint16_t);
112   uint128(float v);
113   uint128(double v);
114 };
115 
// Plain-old-data counterpart of uint128. It is meant for static variables
// that need to be operated on as uint128 values; uint128 has a converting
// constructor from this struct.
//
// The field order is deliberately the reverse of uint128's member order. The
// high half comes first, matching the argument order of the two-argument
// uint128(top, bottom) constructor. That makes brace-initializing static
// instances as {high, low} read naturally, which is the main reason this
// struct exists. The reordering has not been seen to defeat optimizations of
// operations involving this struct.
struct uint128_pod {
  uint64_t hi; // most significant 64 bits
  uint64_t lo; // least significant 64 bits
};
127 
128 C10_API extern const uint128_pod kuint128max;
129 
130 // allow uint128 to be logged
131 C10_API extern std::ostream& operator<<(std::ostream& o, const uint128& b);
132 
133 // Methods to access low and high pieces of 128-bit value.
134 // Defined externally from uint128 to facilitate conversion
135 // to native 128-bit types when compilers support them.
Uint128Low64(const uint128 & v)136 inline uint64_t Uint128Low64(const uint128& v) {
137   return v.lo_;
138 }
Uint128High64(const uint128 & v)139 inline uint64_t Uint128High64(const uint128& v) {
140   return v.hi_;
141 }
142 
// TODO: perhaps it would be nice to have int128, a signed 128-bit type?

// --------------------------------------------------------------------------
//                      Implementation details follow
// --------------------------------------------------------------------------
148 inline bool operator==(const uint128& lhs, const uint128& rhs) {
149   return (
150       Uint128Low64(lhs) == Uint128Low64(rhs) &&
151       Uint128High64(lhs) == Uint128High64(rhs));
152 }
153 inline bool operator!=(const uint128& lhs, const uint128& rhs) {
154   return !(lhs == rhs);
155 }
156 
uint128()157 C10_API inline UINT128_CONSTEXPR uint128::uint128() : lo_(0), hi_(0) {}
uint128(uint64_t top,uint64_t bottom)158 C10_API inline UINT128_CONSTEXPR uint128::uint128(uint64_t top, uint64_t bottom)
159     : lo_(bottom), hi_(top) {}
uint128(const uint128_pod & v)160 C10_API inline UINT128_CONSTEXPR uint128::uint128(const uint128_pod& v)
161     : lo_(v.lo), hi_(v.hi) {}
uint128(uint64_t bottom)162 C10_API inline UINT128_CONSTEXPR uint128::uint128(uint64_t bottom)
163     : lo_(bottom), hi_(0) {}
164 #ifndef SWIG
uint128(uint32_t bottom)165 C10_API inline UINT128_CONSTEXPR uint128::uint128(uint32_t bottom)
166     : lo_(bottom), hi_(0) {}
uint128(int bottom)167 C10_API inline UINT128_CONSTEXPR uint128::uint128(int bottom)
168     : lo_(bottom), hi_(static_cast<int64_t>((bottom < 0) ? -1 : 0)) {}
169 #endif
170 
171 #undef UINT128_CONSTEXPR
172 
Initialize(uint64_t top,uint64_t bottom)173 C10_API inline void uint128::Initialize(uint64_t top, uint64_t bottom) {
174   hi_ = top;
175   lo_ = bottom;
176 }
177 
178 // Comparison operators.
179 
180 #define CMP128(op)                                                  \
181   inline bool operator op(const uint128& lhs, const uint128& rhs) { \
182     return (Uint128High64(lhs) == Uint128High64(rhs))               \
183         ? (Uint128Low64(lhs) op Uint128Low64(rhs))                  \
184         : (Uint128High64(lhs) op Uint128High64(rhs));               \
185   }
186 
187 CMP128(<)
188 CMP128(>)
189 CMP128(>=)
190 CMP128(<=)
191 
192 #undef CMP128
193 
194 // Unary operators
195 
196 inline uint128 operator-(const uint128& val) {
197   const uint64_t hi_flip = ~Uint128High64(val);
198   const uint64_t lo_flip = ~Uint128Low64(val);
199   const uint64_t lo_add = lo_flip + 1;
200   if (lo_add < lo_flip) {
201     return uint128(hi_flip + 1, lo_add);
202   }
203   return uint128(hi_flip, lo_add);
204 }
205 
206 inline bool operator!(const uint128& val) {
207   return !Uint128High64(val) && !Uint128Low64(val);
208 }
209 
210 // Logical operators.
211 
212 inline uint128 operator~(const uint128& val) {
213   return uint128(~Uint128High64(val), ~Uint128Low64(val));
214 }
215 
216 #define LOGIC128(op)                                                   \
217   inline uint128 operator op(const uint128& lhs, const uint128& rhs) { \
218     return uint128(                                                    \
219         Uint128High64(lhs) op Uint128High64(rhs),                      \
220         Uint128Low64(lhs) op Uint128Low64(rhs));                       \
221   }
222 
223 LOGIC128(|)
224 LOGIC128(&)
225 LOGIC128(^)
226 
227 #undef LOGIC128
228 
229 #define LOGICASSIGN128(op)                                              \
230   C10_API inline uint128& uint128::operator op(const uint128 & other) { \
231     hi_ op other.hi_;                                                   \
232     lo_ op other.lo_;                                                   \
233     return *this;                                                       \
234   }
235 
236 LOGICASSIGN128(|=)
237 LOGICASSIGN128(&=)
238 LOGICASSIGN128(^=)
239 
240 #undef LOGICASSIGN128
241 
242 // Shift operators.
243 
244 inline uint128 operator<<(const uint128& val, int amount) {
245   // uint64_t shifts of >= 64 are undefined, so we will need some
246   // special-casing.
247   if (amount < 64) {
248     if (amount == 0) {
249       return val;
250     }
251     uint64_t new_hi =
252         (Uint128High64(val) << amount) | (Uint128Low64(val) >> (64 - amount));
253     uint64_t new_lo = Uint128Low64(val) << amount;
254     return uint128(new_hi, new_lo);
255   } else if (amount < 128) {
256     return uint128(Uint128Low64(val) << (amount - 64), 0);
257   } else {
258     return uint128(0, 0);
259   }
260 }
261 
262 inline uint128 operator>>(const uint128& val, int amount) {
263   // uint64_t shifts of >= 64 are undefined, so we will need some
264   // special-casing.
265   if (amount < 64) {
266     if (amount == 0) {
267       return val;
268     }
269     uint64_t new_hi = Uint128High64(val) >> amount;
270     uint64_t new_lo =
271         (Uint128Low64(val) >> amount) | (Uint128High64(val) << (64 - amount));
272     return uint128(new_hi, new_lo);
273   } else if (amount < 128) {
274     return uint128(0, Uint128High64(val) >> (amount - 64));
275   } else {
276     return uint128(0, 0);
277   }
278 }
279 
280 inline uint128& operator<<=(uint128& self, int amount) {
281   // uint64_t shifts of >= 64 are undefined, so we will need some
282   // special-casing.
283   if (amount < 64) {
284     if (amount != 0) {
285       self.hi_ = (self.hi_ << amount) | (self.lo_ >> (64 - amount));
286       self.lo_ = self.lo_ << amount;
287     }
288   } else if (amount < 128) {
289     self.hi_ = self.lo_ << (amount - 64);
290     self.lo_ = 0;
291   } else {
292     self.hi_ = 0;
293     self.lo_ = 0;
294   }
295   return self;
296 }
297 
298 C10_API inline uint128& uint128::operator>>=(int amount) {
299   // uint64_t shifts of >= 64 are undefined, so we will need some
300   // special-casing.
301   if (amount < 64) {
302     if (amount != 0) {
303       lo_ = (lo_ >> amount) | (hi_ << (64 - amount));
304       hi_ = hi_ >> amount;
305     }
306   } else if (amount < 128) {
307     lo_ = hi_ >> (amount - 64);
308     hi_ = 0;
309   } else {
310     lo_ = 0;
311     hi_ = 0;
312   }
313   return *this;
314 }
315 
316 inline uint128 operator+(const uint128& lhs, const uint128& rhs) {
317   return uint128(lhs) += rhs;
318 }
319 
320 inline uint128 operator-(const uint128& lhs, const uint128& rhs) {
321   return uint128(lhs) -= rhs;
322 }
323 
324 inline uint128 operator*(const uint128& lhs, const uint128& rhs) {
325   return uint128(lhs) *= rhs;
326 }
327 
328 inline uint128 operator/(const uint128& lhs, const uint128& rhs) {
329   return uint128(lhs) /= rhs;
330 }
331 
332 inline uint128 operator%(const uint128& lhs, const uint128& rhs) {
333   return uint128(lhs) %= rhs;
334 }
335 
336 C10_API inline uint128& uint128::operator+=(const uint128& b) {
337   hi_ += b.hi_;
338   uint64_t lolo = lo_ + b.lo_;
339   if (lolo < lo_)
340     ++hi_;
341   lo_ = lolo;
342   return *this;
343 }
344 
345 C10_API inline uint128& uint128::operator-=(const uint128& b) {
346   hi_ -= b.hi_;
347   if (b.lo_ > lo_)
348     --hi_;
349   lo_ -= b.lo_;
350   return *this;
351 }
352 
353 C10_API inline uint128& uint128::operator*=(const uint128& b) {
354   uint64_t a96 = hi_ >> 32;
355   uint64_t a64 = hi_ & 0xffffffffu;
356   uint64_t a32 = lo_ >> 32;
357   uint64_t a00 = lo_ & 0xffffffffu;
358   uint64_t b96 = b.hi_ >> 32;
359   uint64_t b64 = b.hi_ & 0xffffffffu;
360   uint64_t b32 = b.lo_ >> 32;
361   uint64_t b00 = b.lo_ & 0xffffffffu;
362   // multiply [a96 .. a00] x [b96 .. b00]
363   // terms higher than c96 disappear off the high side
364   // terms c96 and c64 are safe to ignore carry bit
365   uint64_t c96 = a96 * b00 + a64 * b32 + a32 * b64 + a00 * b96;
366   uint64_t c64 = a64 * b00 + a32 * b32 + a00 * b64;
367   this->hi_ = (c96 << 32) + c64;
368   this->lo_ = 0;
369   // add terms after this one at a time to capture carry
370   *this += uint128(a32 * b00) << 32;
371   *this += uint128(a00 * b32) << 32;
372   *this += a00 * b00;
373   return *this;
374 }
375 
376 C10_API inline uint128 uint128::operator++(int) {
377   uint128 tmp(*this);
378   *this += 1;
379   return tmp;
380 }
381 
382 C10_API inline uint128 uint128::operator--(int) {
383   uint128 tmp(*this);
384   *this -= 1;
385   return tmp;
386 }
387 
388 C10_API inline uint128& uint128::operator++() {
389   *this += 1;
390   return *this;
391 }
392 
393 C10_API inline uint128& uint128::operator--() {
394   *this -= 1;
395   return *this;
396 }
397 
398 } // namespace c10
399