1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>
14
15 #include "base/check.h"
16 #include "base/check_op.h"
17 #include "base/memory/raw_ptr.h"
18 #include "build/build_config.h"
19
20 namespace base {
21
// Forward declarations. The EnumSet class template (defined below) grants
// friendship to specializations of the Union / Intersection / Difference
// function templates so they can reach its private bitset; naming those
// specializations in a friend declaration requires the templates, and the
// class template they mention, to be declared first.
template <typename Enum, Enum MinValue, Enum MaxValue>
class EnumSet;

template <typename Enum, Enum MinValue, Enum MaxValue>
constexpr EnumSet<Enum, MinValue, MaxValue> Union(
    EnumSet<Enum, MinValue, MaxValue> set1,
    EnumSet<Enum, MinValue, MaxValue> set2);

template <typename Enum, Enum MinValue, Enum MaxValue>
constexpr EnumSet<Enum, MinValue, MaxValue> Intersection(
    EnumSet<Enum, MinValue, MaxValue> set1,
    EnumSet<Enum, MinValue, MaxValue> set2);

template <typename Enum, Enum MinValue, Enum MaxValue>
constexpr EnumSet<Enum, MinValue, MaxValue> Difference(
    EnumSet<Enum, MinValue, MaxValue> set1,
    EnumSet<Enum, MinValue, MaxValue> set2);
37
38 // An EnumSet is a set that can hold enum values between a min and a
39 // max value (inclusive of both). It's essentially a wrapper around
40 // std::bitset<> with stronger type enforcement, more descriptive
41 // member function names, and an iterator interface.
42 //
// If you're working with enums with a small number of possible values
// (say, fewer than 64), you can efficiently pass an EnumSet for that enum
// around by value.
46
47 template <typename E, E MinEnumValue, E MaxEnumValue>
48 class EnumSet {
49 private:
50 static_assert(
51 std::is_enum_v<E>,
52 "First template parameter of EnumSet must be an enumeration type");
53 using enum_underlying_type = std::underlying_type_t<E>;
54
InRange(E value)55 static constexpr bool InRange(E value) {
56 return (value >= MinEnumValue) && (value <= MaxEnumValue);
57 }
58
GetUnderlyingValue(E value)59 static constexpr enum_underlying_type GetUnderlyingValue(E value) {
60 return static_cast<enum_underlying_type>(value);
61 }
62
63 public:
64 using EnumType = E;
65 static const E kMinValue = MinEnumValue;
66 static const E kMaxValue = MaxEnumValue;
67 static const size_t kValueCount =
68 GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
69
70 static_assert(kMinValue <= kMaxValue,
71 "min value must be no greater than max value");
72
73 private:
74 // Declaration needed by Iterator.
75 using EnumBitSet = std::bitset<kValueCount>;
76
77 public:
78 // Iterator is a forward-only read-only iterator for EnumSet. It follows the
79 // common STL input iterator interface (like std::unordered_set).
80 //
81 // Example usage, using a range-based for loop:
82 //
83 // EnumSet<SomeType> enums;
84 // for (SomeType val : enums) {
85 // Process(val);
86 // }
87 //
88 // Or using an explicit iterator (not recommended):
89 //
90 // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
91 // Process(*it);
92 // }
93 //
94 // The iterator must not be outlived by the set. In particular, the following
95 // is an error:
96 //
97 // EnumSet<...> SomeFn() { ... }
98 //
99 // /* ERROR */
100 // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
101 //
102 // Also, there are no guarantees as to what will happen if you
103 // modify an EnumSet while traversing it with an iterator.
104 class Iterator {
105 public:
106 using value_type = EnumType;
107 using size_type = size_t;
108 using difference_type = ptrdiff_t;
109 using pointer = EnumType*;
110 using reference = EnumType&;
111 using iterator_category = std::forward_iterator_tag;
112
Iterator()113 Iterator() : enums_(nullptr), i_(kValueCount) {}
114 ~Iterator() = default;
115
116 friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
117 return lhs.i_ == rhs.i_;
118 }
119
120 value_type operator*() const {
121 DCHECK(Good());
122 return FromIndex(i_);
123 }
124
125 Iterator& operator++() {
126 DCHECK(Good());
127 // If there are no more set elements in the bitset, this will result in an
128 // index equal to kValueCount, which is equivalent to EnumSet.end().
129 i_ = FindNext(i_ + 1);
130
131 return *this;
132 }
133
134 Iterator operator++(int) {
135 DCHECK(Good());
136 Iterator old(*this);
137
138 // If there are no more set elements in the bitset, this will result in an
139 // index equal to kValueCount, which is equivalent to EnumSet.end().
140 i_ = FindNext(i_ + 1);
141
142 return std::move(old);
143 }
144
145 private:
146 friend Iterator EnumSet::begin() const;
147
Iterator(const EnumBitSet & enums)148 explicit Iterator(const EnumBitSet& enums)
149 : enums_(&enums), i_(FindNext(0)) {}
150
151 // Returns true iff the iterator points to an EnumSet and it
152 // hasn't yet traversed the EnumSet entirely.
Good()153 bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
154
FindNext(size_t i)155 size_t FindNext(size_t i) {
156 while ((i < kValueCount) && !enums_->test(i)) {
157 ++i;
158 }
159 return i;
160 }
161
162 const raw_ptr<const EnumBitSet> enums_;
163 size_t i_;
164 };
165
166 EnumSet() = default;
167
168 ~EnumSet() = default;
169
EnumSet(std::initializer_list<E> values)170 constexpr EnumSet(std::initializer_list<E> values) {
171 if (std::is_constant_evaluated()) {
172 enums_ = bitstring(values);
173 } else {
174 for (E value : values) {
175 Put(value);
176 }
177 }
178 }
179
180 // Returns an EnumSet with all values between kMinValue and kMaxValue, which
181 // also contains undefined enum values if the enum in question has gaps
182 // between kMinValue and kMaxValue.
All()183 static constexpr EnumSet All() {
184 if (std::is_constant_evaluated()) {
185 if (kValueCount == 0) {
186 return EnumSet();
187 }
188 // Since `1 << kValueCount` may trigger shift-count-overflow warning if
189 // the `kValueCount` is 64, instead of returning `(1 << kValueCount) - 1`,
190 // the bitmask will be constructed from two parts: the most significant
191 // bits and the remaining.
192 uint64_t mask = 1ULL << (kValueCount - 1);
193 return EnumSet(EnumBitSet(mask - 1 + mask));
194 } else {
195 // When `kValueCount` is greater than 64, we can't use the constexpr path,
196 // and we will build an `EnumSet` value by value.
197 EnumSet enum_set;
198 for (size_t value = 0; value < kValueCount; ++value) {
199 enum_set.Put(FromIndex(value));
200 }
201 return enum_set;
202 }
203 }
204
205 // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)206 static constexpr EnumSet FromRange(E start, E end) {
207 CHECK_LE(start, end);
208 return EnumSet(EnumBitSet(
209 ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
210 (single_val_bitstring(end))));
211 }
212
213 // Copy constructor and assignment welcome.
214
215 // Bitmask operations.
216 //
217 // This bitmask is 0-based and the value of the Nth bit depends on whether
218 // the set contains an enum element of integer value N.
219 //
220 // These may only be used if Min >= 0 and Max < 64.
221
222 // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)223 static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
224 static_assert(GetUnderlyingValue(kMaxValue) < 64,
225 "The highest enum value must be < 64 for FromEnumBitmask ");
226 static_assert(GetUnderlyingValue(kMinValue) >= 0,
227 "The lowest enum value must be >= 0 for FromEnumBitmask ");
228 return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
229 }
230 // Returns a bitmask for the EnumSet.
ToEnumBitmask()231 uint64_t ToEnumBitmask() const {
232 static_assert(GetUnderlyingValue(kMaxValue) < 64,
233 "The highest enum value must be < 64 for ToEnumBitmask ");
234 static_assert(GetUnderlyingValue(kMinValue) >= 0,
235 "The lowest enum value must be >= 0 for FromEnumBitmask ");
236 return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
237 }
238
239 // Set operations. Put, Retain, and Remove are basically
240 // self-mutating versions of Union, Intersection, and Difference
241 // (defined below).
242
243 // Adds the given value (which must be in range) to our set.
Put(E value)244 void Put(E value) { enums_.set(ToIndex(value)); }
245
246 // Adds all values in the given set to our set.
PutAll(EnumSet other)247 void PutAll(EnumSet other) { enums_ |= other.enums_; }
248
249 // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)250 void PutRange(E start, E end) {
251 CHECK_LE(start, end);
252 size_t endIndexInclusive = ToIndex(end);
253 for (size_t current = ToIndex(start); current <= endIndexInclusive;
254 ++current) {
255 enums_.set(current);
256 }
257 }
258
259 // There's no real need for a Retain(E) member function.
260
261 // Removes all values not in the given set from our set.
RetainAll(EnumSet other)262 void RetainAll(EnumSet other) { enums_ &= other.enums_; }
263
264 // If the given value is in range, removes it from our set.
Remove(E value)265 void Remove(E value) {
266 if (InRange(value)) {
267 enums_.reset(ToIndex(value));
268 }
269 }
270
271 // Removes all values in the given set from our set.
RemoveAll(EnumSet other)272 void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
273
274 // Removes all values from our set.
Clear()275 void Clear() { enums_.reset(); }
276
277 // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)278 void PutOrRemove(E value, bool should_be_present) {
279 if (should_be_present) {
280 Put(value);
281 } else {
282 Remove(value);
283 }
284 }
285
286 // Returns true iff the given value is in range and a member of our set.
Has(E value)287 constexpr bool Has(E value) const {
288 return InRange(value) && enums_[ToIndex(value)];
289 }
290
291 // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)292 bool HasAll(EnumSet other) const {
293 return (enums_ & other.enums_) == other.enums_;
294 }
295
296 // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)297 bool HasAny(EnumSet other) const {
298 return (enums_ & other.enums_).count() > 0;
299 }
300
301 // Returns true iff our set is empty.
empty()302 bool empty() const { return !enums_.any(); }
303
304 // Returns how many values our set has.
size()305 size_t size() const { return enums_.count(); }
306
307 // Returns an iterator pointing to the first element (if any).
begin()308 Iterator begin() const { return Iterator(enums_); }
309
310 // Returns an iterator that does not point to any element, but to the position
311 // that follows the last element in the set.
end()312 Iterator end() const { return Iterator(); }
313
314 // Returns true iff our set and the given set contain exactly the same values.
315 friend bool operator==(const EnumSet&, const EnumSet&) = default;
316
ToString()317 std::string ToString() const { return enums_.to_string(); }
318
319 private:
320 friend constexpr EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
321 EnumSet set2);
322 friend constexpr EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(
323 EnumSet set1,
324 EnumSet set2);
325 friend constexpr EnumSet Difference<E, MinEnumValue, MaxEnumValue>(
326 EnumSet set1,
327 EnumSet set2);
328
bitstring(const std::initializer_list<E> & values)329 static constexpr uint64_t bitstring(const std::initializer_list<E>& values) {
330 uint64_t result = 0;
331 for (E value : values) {
332 result |= single_val_bitstring(value);
333 }
334 return result;
335 }
336
single_val_bitstring(E val)337 static constexpr uint64_t single_val_bitstring(E val) {
338 const uint64_t bitstring = 1;
339 const size_t shift_amount = ToIndex(val);
340 CHECK_LT(shift_amount, sizeof(bitstring) * 8);
341 return bitstring << shift_amount;
342 }
343
344 // A bitset can't be constexpr constructed if it has size > 64, since the
345 // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
346 // can safely remove the constepxr qualifiers from this file, at the cost of
347 // some minor optimizations.
EnumSet(EnumBitSet enums)348 explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
349 if (std::is_constant_evaluated()) {
350 CHECK(kValueCount <= 64)
351 << "Max number of enum values is 64 for constexpr constructor";
352 }
353 }
354
355 // Converts a value to/from an index into |enums_|.
ToIndex(E value)356 static constexpr size_t ToIndex(E value) {
357 CHECK(InRange(value));
358 return static_cast<size_t>(GetUnderlyingValue(value)) -
359 static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
360 }
361
FromIndex(size_t i)362 static E FromIndex(size_t i) {
363 DCHECK_LT(i, kValueCount);
364 return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
365 }
366
367 EnumBitSet enums_;
368 };
369
370 template <typename E, E MinEnumValue, E MaxEnumValue>
371 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;
372
373 template <typename E, E MinEnumValue, E MaxEnumValue>
374 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;
375
376 template <typename E, E MinEnumValue, E MaxEnumValue>
377 const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
378
379 // The usual set operations.
380
381 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)382 constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
383 EnumSet<E, Min, Max> set2) {
384 return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
385 }
386
387 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)388 constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
389 EnumSet<E, Min, Max> set2) {
390 return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
391 }
392
393 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)394 constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
395 EnumSet<E, Min, Max> set2) {
396 return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
397 }
398
399 } // namespace base
400
401 #endif // BASE_CONTAINERS_ENUM_SET_H_
402