xref: /aosp_15_r20/external/cronet/base/containers/enum_set.h (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7 
8 #include <bitset>
9 #include <cstddef>
10 #include <initializer_list>
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 
15 #include "base/check.h"
16 #include "base/check_op.h"
17 #include "base/memory/raw_ptr.h"
18 #include "build/build_config.h"
19 
20 namespace base {
21 
// EnumSet befriends the free set-operation templates declared here, so both
// the class template and those functions must be declared ahead of the class
// definition.
template <typename E, E MinEnumValue, E MaxEnumValue>
class EnumSet;

// Returns the set of values present in |set1|, |set2|, or both.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
                                     EnumSet<E, Min, Max> set2);

// Returns the set of values present in both |set1| and |set2|.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
                                            EnumSet<E, Min, Max> set2);

// Returns the set of values present in |set1| but not in |set2|.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
                                          EnumSet<E, Min, Max> set2);
37 
38 // An EnumSet is a set that can hold enum values between a min and a
39 // max value (inclusive of both).  It's essentially a wrapper around
40 // std::bitset<> with stronger type enforcement, more descriptive
41 // member function names, and an iterator interface.
42 //
43 // If you're working with enums with a small number of possible values
44 // (say, fewer than 64), you can efficiently pass around an EnumSet
45 // for that enum around by value.
46 
47 template <typename E, E MinEnumValue, E MaxEnumValue>
48 class EnumSet {
49  private:
50   static_assert(
51       std::is_enum_v<E>,
52       "First template parameter of EnumSet must be an enumeration type");
53   using enum_underlying_type = std::underlying_type_t<E>;
54 
InRange(E value)55   static constexpr bool InRange(E value) {
56     return (value >= MinEnumValue) && (value <= MaxEnumValue);
57   }
58 
GetUnderlyingValue(E value)59   static constexpr enum_underlying_type GetUnderlyingValue(E value) {
60     return static_cast<enum_underlying_type>(value);
61   }
62 
63  public:
64   using EnumType = E;
65   static const E kMinValue = MinEnumValue;
66   static const E kMaxValue = MaxEnumValue;
67   static const size_t kValueCount =
68       GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
69 
70   static_assert(kMinValue <= kMaxValue,
71                 "min value must be no greater than max value");
72 
73  private:
74   // Declaration needed by Iterator.
75   using EnumBitSet = std::bitset<kValueCount>;
76 
77  public:
78   // Iterator is a forward-only read-only iterator for EnumSet. It follows the
79   // common STL input iterator interface (like std::unordered_set).
80   //
81   // Example usage, using a range-based for loop:
82   //
83   // EnumSet<SomeType> enums;
84   // for (SomeType val : enums) {
85   //   Process(val);
86   // }
87   //
88   // Or using an explicit iterator (not recommended):
89   //
90   // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
91   //   Process(*it);
92   // }
93   //
94   // The iterator must not be outlived by the set. In particular, the following
95   // is an error:
96   //
97   // EnumSet<...> SomeFn() { ... }
98   //
99   // /* ERROR */
100   // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
101   //
102   // Also, there are no guarantees as to what will happen if you
103   // modify an EnumSet while traversing it with an iterator.
104   class Iterator {
105    public:
106     using value_type = EnumType;
107     using size_type = size_t;
108     using difference_type = ptrdiff_t;
109     using pointer = EnumType*;
110     using reference = EnumType&;
111     using iterator_category = std::forward_iterator_tag;
112 
Iterator()113     Iterator() : enums_(nullptr), i_(kValueCount) {}
114     ~Iterator() = default;
115 
116     friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
117       return lhs.i_ == rhs.i_;
118     }
119 
120     value_type operator*() const {
121       DCHECK(Good());
122       return FromIndex(i_);
123     }
124 
125     Iterator& operator++() {
126       DCHECK(Good());
127       // If there are no more set elements in the bitset, this will result in an
128       // index equal to kValueCount, which is equivalent to EnumSet.end().
129       i_ = FindNext(i_ + 1);
130 
131       return *this;
132     }
133 
134     Iterator operator++(int) {
135       DCHECK(Good());
136       Iterator old(*this);
137 
138       // If there are no more set elements in the bitset, this will result in an
139       // index equal to kValueCount, which is equivalent to EnumSet.end().
140       i_ = FindNext(i_ + 1);
141 
142       return std::move(old);
143     }
144 
145    private:
146     friend Iterator EnumSet::begin() const;
147 
Iterator(const EnumBitSet & enums)148     explicit Iterator(const EnumBitSet& enums)
149         : enums_(&enums), i_(FindNext(0)) {}
150 
151     // Returns true iff the iterator points to an EnumSet and it
152     // hasn't yet traversed the EnumSet entirely.
Good()153     bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
154 
FindNext(size_t i)155     size_t FindNext(size_t i) {
156       while ((i < kValueCount) && !enums_->test(i)) {
157         ++i;
158       }
159       return i;
160     }
161 
162     const raw_ptr<const EnumBitSet> enums_;
163     size_t i_;
164   };
165 
166   EnumSet() = default;
167 
168   ~EnumSet() = default;
169 
EnumSet(std::initializer_list<E> values)170   constexpr EnumSet(std::initializer_list<E> values) {
171     if (std::is_constant_evaluated()) {
172       enums_ = bitstring(values);
173     } else {
174       for (E value : values) {
175         Put(value);
176       }
177     }
178   }
179 
180   // Returns an EnumSet with all values between kMinValue and kMaxValue, which
181   // also contains undefined enum values if the enum in question has gaps
182   // between kMinValue and kMaxValue.
All()183   static constexpr EnumSet All() {
184     if (std::is_constant_evaluated()) {
185       if (kValueCount == 0) {
186         return EnumSet();
187       }
188       // Since `1 << kValueCount` may trigger shift-count-overflow warning if
189       // the `kValueCount` is 64, instead of returning `(1 << kValueCount) - 1`,
190       // the bitmask will be constructed from two parts: the most significant
191       // bits and the remaining.
192       uint64_t mask = 1ULL << (kValueCount - 1);
193       return EnumSet(EnumBitSet(mask - 1 + mask));
194     } else {
195       // When `kValueCount` is greater than 64, we can't use the constexpr path,
196       // and we will build an `EnumSet` value by value.
197       EnumSet enum_set;
198       for (size_t value = 0; value < kValueCount; ++value) {
199         enum_set.Put(FromIndex(value));
200       }
201       return enum_set;
202     }
203   }
204 
205   // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)206   static constexpr EnumSet FromRange(E start, E end) {
207     CHECK_LE(start, end);
208     return EnumSet(EnumBitSet(
209         ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
210         (single_val_bitstring(end))));
211   }
212 
213   // Copy constructor and assignment welcome.
214 
215   // Bitmask operations.
216   //
217   // This bitmask is 0-based and the value of the Nth bit depends on whether
218   // the set contains an enum element of integer value N.
219   //
220   // These may only be used if Min >= 0 and Max < 64.
221 
222   // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)223   static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
224     static_assert(GetUnderlyingValue(kMaxValue) < 64,
225                   "The highest enum value must be < 64 for FromEnumBitmask ");
226     static_assert(GetUnderlyingValue(kMinValue) >= 0,
227                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
228     return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
229   }
230   // Returns a bitmask for the EnumSet.
ToEnumBitmask()231   uint64_t ToEnumBitmask() const {
232     static_assert(GetUnderlyingValue(kMaxValue) < 64,
233                   "The highest enum value must be < 64 for ToEnumBitmask ");
234     static_assert(GetUnderlyingValue(kMinValue) >= 0,
235                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
236     return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
237   }
238 
239   // Set operations.  Put, Retain, and Remove are basically
240   // self-mutating versions of Union, Intersection, and Difference
241   // (defined below).
242 
243   // Adds the given value (which must be in range) to our set.
Put(E value)244   void Put(E value) { enums_.set(ToIndex(value)); }
245 
246   // Adds all values in the given set to our set.
PutAll(EnumSet other)247   void PutAll(EnumSet other) { enums_ |= other.enums_; }
248 
249   // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)250   void PutRange(E start, E end) {
251     CHECK_LE(start, end);
252     size_t endIndexInclusive = ToIndex(end);
253     for (size_t current = ToIndex(start); current <= endIndexInclusive;
254          ++current) {
255       enums_.set(current);
256     }
257   }
258 
259   // There's no real need for a Retain(E) member function.
260 
261   // Removes all values not in the given set from our set.
RetainAll(EnumSet other)262   void RetainAll(EnumSet other) { enums_ &= other.enums_; }
263 
264   // If the given value is in range, removes it from our set.
Remove(E value)265   void Remove(E value) {
266     if (InRange(value)) {
267       enums_.reset(ToIndex(value));
268     }
269   }
270 
271   // Removes all values in the given set from our set.
RemoveAll(EnumSet other)272   void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
273 
274   // Removes all values from our set.
Clear()275   void Clear() { enums_.reset(); }
276 
277   // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)278   void PutOrRemove(E value, bool should_be_present) {
279     if (should_be_present) {
280       Put(value);
281     } else {
282       Remove(value);
283     }
284   }
285 
286   // Returns true iff the given value is in range and a member of our set.
Has(E value)287   constexpr bool Has(E value) const {
288     return InRange(value) && enums_[ToIndex(value)];
289   }
290 
291   // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)292   bool HasAll(EnumSet other) const {
293     return (enums_ & other.enums_) == other.enums_;
294   }
295 
296   // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)297   bool HasAny(EnumSet other) const {
298     return (enums_ & other.enums_).count() > 0;
299   }
300 
301   // Returns true iff our set is empty.
empty()302   bool empty() const { return !enums_.any(); }
303 
304   // Returns how many values our set has.
size()305   size_t size() const { return enums_.count(); }
306 
307   // Returns an iterator pointing to the first element (if any).
begin()308   Iterator begin() const { return Iterator(enums_); }
309 
310   // Returns an iterator that does not point to any element, but to the position
311   // that follows the last element in the set.
end()312   Iterator end() const { return Iterator(); }
313 
314   // Returns true iff our set and the given set contain exactly the same values.
315   friend bool operator==(const EnumSet&, const EnumSet&) = default;
316 
ToString()317   std::string ToString() const { return enums_.to_string(); }
318 
319  private:
320   friend constexpr EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
321                                                                 EnumSet set2);
322   friend constexpr EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(
323       EnumSet set1,
324       EnumSet set2);
325   friend constexpr EnumSet Difference<E, MinEnumValue, MaxEnumValue>(
326       EnumSet set1,
327       EnumSet set2);
328 
bitstring(const std::initializer_list<E> & values)329   static constexpr uint64_t bitstring(const std::initializer_list<E>& values) {
330     uint64_t result = 0;
331     for (E value : values) {
332       result |= single_val_bitstring(value);
333     }
334     return result;
335   }
336 
single_val_bitstring(E val)337   static constexpr uint64_t single_val_bitstring(E val) {
338     const uint64_t bitstring = 1;
339     const size_t shift_amount = ToIndex(val);
340     CHECK_LT(shift_amount, sizeof(bitstring) * 8);
341     return bitstring << shift_amount;
342   }
343 
344   // A bitset can't be constexpr constructed if it has size > 64, since the
345   // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
346   // can safely remove the constepxr qualifiers from this file, at the cost of
347   // some minor optimizations.
EnumSet(EnumBitSet enums)348   explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
349     if (std::is_constant_evaluated()) {
350       CHECK(kValueCount <= 64)
351           << "Max number of enum values is 64 for constexpr constructor";
352     }
353   }
354 
355   // Converts a value to/from an index into |enums_|.
ToIndex(E value)356   static constexpr size_t ToIndex(E value) {
357     CHECK(InRange(value));
358     return static_cast<size_t>(GetUnderlyingValue(value)) -
359            static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
360   }
361 
FromIndex(size_t i)362   static E FromIndex(size_t i) {
363     DCHECK_LT(i, kValueCount);
364     return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
365   }
366 
367   EnumBitSet enums_;
368 };
369 
370 template <typename E, E MinEnumValue, E MaxEnumValue>
371 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;
372 
373 template <typename E, E MinEnumValue, E MaxEnumValue>
374 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;
375 
376 template <typename E, E MinEnumValue, E MaxEnumValue>
377 const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
378 
379 // The usual set operations.
380 
381 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)382 constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
383                                      EnumSet<E, Min, Max> set2) {
384   return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
385 }
386 
387 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)388 constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
389                                             EnumSet<E, Min, Max> set2) {
390   return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
391 }
392 
393 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)394 constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
395                                           EnumSet<E, Min, Max> set2) {
396   return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
397 }
398 
399 }  // namespace base
400 
401 #endif  // BASE_CONTAINERS_ENUM_SET_H_
402