xref: /aosp_15_r20/external/pytorch/c10/util/Bitset.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
#include <cstddef>
#include <cstdint>
#if defined(_MSC_VER)
#include <intrin.h>
#endif
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker namespace c10::utils {
9*da0073e9SAndroid Build Coastguard Worker 
/**
 * This is a simple bitset class with 8 * sizeof(bitset_type) bits (see
 * NUM_BITS()).
 * You can set bits, unset bits, query bits by index,
 * and iterate over the set bits in ascending index order.
 * Before using this class, please also take a look at std::bitset,
 * which has more functionality and is more generic. It is probably
 * a better fit for your use case. The sole reason for c10::utils::bitset
 * to exist is that std::bitset misses a find_first_set() method.
 *
 * Precondition for every index-taking method: index < NUM_BITS(). Larger
 * indices produce an out-of-range shift, which is undefined behavior.
 */
struct bitset final {
 private:
#if defined(_MSC_VER)
  // MSVC's _BitScanForward64 takes an unsigned __int64 mask; int64_t has the
  // same width and converts implicitly.
  using bitset_type = int64_t;
#else
  // POSIX ffsll / __builtin_ffsll take a long long int.
  using bitset_type = long long int;
#endif

  // Word with only bit `index` set. Kept in bitset_type so set/unset/get all
  // operate on the same type as the storage on every platform.
  static constexpr bitset_type bit_mask(size_t index) noexcept {
    return static_cast<bitset_type>(1) << index;
  }

 public:
  /// Number of bits this bitset can hold.
  static constexpr size_t NUM_BITS() {
    return 8 * sizeof(bitset_type);
  }

  constexpr bitset() noexcept = default;
  constexpr bitset(const bitset&) noexcept = default;
  constexpr bitset(bitset&&) noexcept = default;
  // There is an issue with gcc 5.3.0 when defaulted functions are declared
  // constexpr, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
  bitset& operator=(const bitset&) noexcept = default;
  bitset& operator=(bitset&&) noexcept = default;

  /// Sets bit `index` (0-based). Requires index < NUM_BITS().
  constexpr void set(size_t index) noexcept {
    bitset_ |= bit_mask(index);
  }

  /// Clears bit `index` (0-based). Requires index < NUM_BITS().
  constexpr void unset(size_t index) noexcept {
    bitset_ &= ~bit_mask(index);
  }

  /// Returns whether bit `index` (0-based) is set. Requires index < NUM_BITS().
  constexpr bool get(size_t index) const noexcept {
    return (bitset_ & bit_mask(index)) != 0;
  }

  /// Returns true when no bit is set.
  constexpr bool is_entirely_unset() const noexcept {
    return 0 == bitset_;
  }

  // Call the given functor with the (0-based) index of each bit that is set,
  // in ascending index order. Operates on a copy, so `func` may freely modify
  // *this without affecting the iteration.
  template <class Func>
  void for_each_set_bit(Func&& func) const {
    bitset cur = *this;
    size_t index = cur.find_first_set();
    while (0 != index) {
      // -1 because find_first_set() is one-indexed (0 means "no bit set").
      index -= 1;
      func(index);
      cur.unset(index);
      index = cur.find_first_set();
    }
  }

 private:
  // Return the index of the first set bit. The returned index is one-indexed
  // (i.e. if the very first bit is set, this function returns '1'), and a
  // return of '0' means that there was no bit set.
  size_t find_first_set() const {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
    unsigned long result;
    if (0 == _BitScanForward64(&result, bitset_)) {
      return 0;
    }
    return result + 1;
#elif defined(_MSC_VER)
    // Targets without _BitScanForward64 (e.g. 32-bit x86/ARM): scan the low
    // 32-bit word first, then the high word.
    unsigned long result;
    if (0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_))) {
      return result + 1;
    }
    if (0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_ >> 32))) {
      return result + 33;
    }
    // No bit set anywhere. This must be 0: any non-zero value would be read by
    // for_each_set_bit() as a set bit that unset() can never clear, so the
    // loop would never terminate.
    return 0;
#else
    return __builtin_ffsll(bitset_);
#endif
  }

  friend constexpr bool operator==(bitset lhs, bitset rhs) noexcept {
    return lhs.bitset_ == rhs.bitset_;
  }

  bitset_type bitset_{0};
};
111*da0073e9SAndroid Build Coastguard Worker 
112*da0073e9SAndroid Build Coastguard Worker inline bool operator!=(bitset lhs, bitset rhs) noexcept {
113*da0073e9SAndroid Build Coastguard Worker   return !(lhs == rhs);
114*da0073e9SAndroid Build Coastguard Worker }
115*da0073e9SAndroid Build Coastguard Worker 
116*da0073e9SAndroid Build Coastguard Worker } // namespace c10::utils
117