xref: /aosp_15_r20/external/pytorch/c10/util/sparse_bitset.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SparseBitVector class.  See the doxygen comment for
10 // SparseBitVector for more details on the algorithm used.
11 //
12 //===----------------------------------------------------------------------===//
13 
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/llvmMathExtras.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <iterator>
#include <list>
#include <ostream>
#include <stdexcept>
#include <utility>
23 
24 namespace c10 {
25 
/// SparseBitVector is an implementation of a bitvector that is sparse by only
/// storing the elements that have non-zero bits set.  In order to make this
/// fast for the most common cases, SparseBitVector is implemented as a linked
/// list of SparseBitVectorElements.  We maintain a pointer to the last
/// SparseBitVectorElement accessed (in the form of a list iterator), in order
/// to make multiple in-order test/set constant time after the first one is
/// executed.  Note that using vectors to store SparseBitVectorElements does
/// not work out very well because it causes insertion in the middle to take
/// enormous amounts of time with a large amount of bits.  Other structures that
/// have better worst cases for insertion in the middle (various balanced trees,
/// etc) do not perform as well in practice as a linked list with this iterator
/// kept up to date.  They are also significantly more memory intensive.
38 
39 template <unsigned ElementSize = 128>
40 struct SparseBitVectorElement {
41  public:
42   using BitWord = unsigned long;
43   using size_type = unsigned;
44   enum {
45     BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT,
46     BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE,
47     BITS_PER_ELEMENT = ElementSize
48   };
49 
50  private:
51   // Index of Element in terms of where first bit starts.
52   unsigned ElementIndex;
53   std::array<BitWord, BITWORDS_PER_ELEMENT> Bits{};
54 
SparseBitVectorElementSparseBitVectorElement55   SparseBitVectorElement() : ElementIndex(~0U) {}
56 
57  public:
SparseBitVectorElementSparseBitVectorElement58   explicit SparseBitVectorElement(unsigned Idx) : ElementIndex(Idx) {}
59 
60   // Comparison.
61   bool operator==(const SparseBitVectorElement& RHS) const {
62     if (ElementIndex != RHS.ElementIndex)
63       return false;
64     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
65       if (Bits[i] != RHS.Bits[i])
66         return false;
67     return true;
68   }
69 
70   bool operator!=(const SparseBitVectorElement& RHS) const {
71     return !(*this == RHS);
72   }
73 
74   // Return the bits that make up word Idx in our element.
wordSparseBitVectorElement75   BitWord word(unsigned Idx) const {
76     assert(Idx < BITWORDS_PER_ELEMENT);
77     return Bits[Idx];
78   }
79 
indexSparseBitVectorElement80   unsigned index() const {
81     return ElementIndex;
82   }
83 
emptySparseBitVectorElement84   bool empty() const {
85     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
86       if (Bits[i])
87         return false;
88     return true;
89   }
90 
setSparseBitVectorElement91   void set(unsigned Idx) {
92     Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE);
93   }
94 
test_and_setSparseBitVectorElement95   bool test_and_set(unsigned Idx) {
96     bool old = test(Idx);
97     if (!old) {
98       set(Idx);
99       return true;
100     }
101     return false;
102   }
103 
resetSparseBitVectorElement104   void reset(unsigned Idx) {
105     Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE));
106   }
107 
testSparseBitVectorElement108   bool test(unsigned Idx) const {
109     return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE));
110   }
111 
countSparseBitVectorElement112   size_type count() const {
113     unsigned NumBits = 0;
114     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
115       NumBits += llvm::countPopulation(Bits[i]);
116     return NumBits;
117   }
118 
119   /// find_first - Returns the index of the first set bit.
find_firstSparseBitVectorElement120   int find_first() const {
121     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
122       if (Bits[i] != 0)
123         return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
124     throw std::runtime_error("Illegal empty element");
125   }
126 
127   /// find_last - Returns the index of the last set bit.
find_lastSparseBitVectorElement128   int find_last() const {
129     for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) {
130       unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
131       if (Bits[Idx] != 0)
132         return Idx * BITWORD_SIZE + BITWORD_SIZE -
133             llvm::countLeadingZeros(Bits[Idx]);
134     }
135     throw std::runtime_error("Illegal empty element");
136   }
137 
138   /// find_next - Returns the index of the next set bit starting from the
139   /// "Curr" bit. Returns -1 if the next set bit is not found.
find_nextSparseBitVectorElement140   int find_next(unsigned Curr) const {
141     if (Curr >= BITS_PER_ELEMENT)
142       return -1;
143 
144     unsigned WordPos = Curr / BITWORD_SIZE;
145     unsigned BitPos = Curr % BITWORD_SIZE;
146     BitWord Copy = Bits[WordPos];
147     assert(
148         WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element");
149 
150     // Mask off previous bits.
151     Copy &= ~0UL << BitPos;
152 
153     if (Copy != 0)
154       return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);
155 
156     // Check subsequent words.
157     for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i)
158       if (Bits[i] != 0)
159         return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
160     return -1;
161   }
162 
163   // Union this element with RHS and return true if this one changed.
unionWithSparseBitVectorElement164   bool unionWith(const SparseBitVectorElement& RHS) {
165     bool changed = false;
166     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
167       BitWord old = changed ? 0 : Bits[i];
168 
169       Bits[i] |= RHS.Bits[i];
170       if (!changed && old != Bits[i])
171         changed = true;
172     }
173     return changed;
174   }
175 
176   // Return true if we have any bits in common with RHS
intersectsSparseBitVectorElement177   bool intersects(const SparseBitVectorElement& RHS) const {
178     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
179       if (RHS.Bits[i] & Bits[i])
180         return true;
181     }
182     return false;
183   }
184 
185   // Intersect this Element with RHS and return true if this one changed.
186   // BecameZero is set to true if this element became all-zero bits.
intersectWithSparseBitVectorElement187   bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) {
188     bool changed = false;
189     bool allzero = true;
190 
191     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
192       BitWord old = changed ? 0 : Bits[i];
193 
194       Bits[i] &= RHS.Bits[i];
195       if (Bits[i] != 0)
196         allzero = false;
197 
198       if (!changed && old != Bits[i])
199         changed = true;
200     }
201     BecameZero = allzero;
202     return changed;
203   }
204 
205   // Intersect this Element with the complement of RHS and return true if this
206   // one changed.  BecameZero is set to true if this element became all-zero
207   // bits.
intersectWithComplementSparseBitVectorElement208   bool intersectWithComplement(
209       const SparseBitVectorElement& RHS,
210       bool& BecameZero) {
211     bool changed = false;
212     bool allzero = true;
213 
214     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
215       BitWord old = changed ? 0 : Bits[i];
216 
217       Bits[i] &= ~RHS.Bits[i];
218       if (Bits[i] != 0)
219         allzero = false;
220 
221       if (!changed && old != Bits[i])
222         changed = true;
223     }
224     BecameZero = allzero;
225     return changed;
226   }
227 
228   // Three argument version of intersectWithComplement that intersects
229   // RHS1 & ~RHS2 into this element
intersectWithComplementSparseBitVectorElement230   void intersectWithComplement(
231       const SparseBitVectorElement& RHS1,
232       const SparseBitVectorElement& RHS2,
233       bool& BecameZero) {
234     bool allzero = true;
235 
236     for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
237       Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i];
238       if (Bits[i] != 0)
239         allzero = false;
240     }
241     BecameZero = allzero;
242   }
243 };
244 
245 template <unsigned ElementSize = 128>
246 class SparseBitVector {
247   using ElementList = std::list<SparseBitVectorElement<ElementSize>>;
248   using ElementListIter = typename ElementList::iterator;
249   using ElementListConstIter = typename ElementList::const_iterator;
250   enum { BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE };
251 
252   ElementList Elements;
253   // Pointer to our current Element. This has no visible effect on the external
254   // state of a SparseBitVector, it's just used to improve performance in the
255   // common case of testing/modifying bits with similar indices.
256   mutable ElementListIter CurrElementIter;
257 
258   // This is like std::lower_bound, except we do linear searching from the
259   // current position.
FindLowerBoundImpl(unsigned ElementIndex)260   ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const {
261     // We cache a non-const iterator so we're forced to resort to const_cast to
262     // get the begin/end in the case where 'this' is const. To avoid duplication
263     // of code with the only difference being whether the const cast is present
264     // 'this' is always const in this particular function and we sort out the
265     // difference in FindLowerBound and FindLowerBoundConst.
266     ElementListIter Begin =
267         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
268         const_cast<SparseBitVector<ElementSize>*>(this)->Elements.begin();
269     ElementListIter End =
270         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
271         const_cast<SparseBitVector<ElementSize>*>(this)->Elements.end();
272 
273     if (Elements.empty()) {
274       CurrElementIter = Begin;
275       return CurrElementIter;
276     }
277 
278     // Make sure our current iterator is valid.
279     if (CurrElementIter == End)
280       --CurrElementIter;
281 
282     // Search from our current iterator, either backwards or forwards,
283     // depending on what element we are looking for.
284     ElementListIter ElementIter = CurrElementIter;
285     if (CurrElementIter->index() == ElementIndex) {
286       return ElementIter;
287     } else if (CurrElementIter->index() > ElementIndex) {
288       while (ElementIter != Begin && ElementIter->index() > ElementIndex)
289         --ElementIter;
290     } else {
291       while (ElementIter != End && ElementIter->index() < ElementIndex)
292         ++ElementIter;
293     }
294     CurrElementIter = ElementIter;
295     return ElementIter;
296   }
FindLowerBoundConst(unsigned ElementIndex)297   ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const {
298     return FindLowerBoundImpl(ElementIndex);
299   }
FindLowerBound(unsigned ElementIndex)300   ElementListIter FindLowerBound(unsigned ElementIndex) {
301     return FindLowerBoundImpl(ElementIndex);
302   }
303 
304   // Iterator to walk set bits in the bitmap.  This iterator is a lot uglier
305   // than it would be, in order to be efficient.
306   class SparseBitVectorIterator {
307    private:
308     bool AtEnd{false};
309 
310     const SparseBitVector<ElementSize>* BitVector = nullptr;
311 
312     // Current element inside of bitmap.
313     ElementListConstIter Iter;
314 
315     // Current bit number inside of our bitmap.
316     unsigned BitNumber{0};
317 
318     // Current word number inside of our element.
319     unsigned WordNumber{0};
320 
321     // Current bits from the element.
322     typename SparseBitVectorElement<ElementSize>::BitWord Bits{0};
323 
324     // Move our iterator to the first non-zero bit in the bitmap.
AdvanceToFirstNonZero()325     void AdvanceToFirstNonZero() {
326       if (AtEnd)
327         return;
328       if (BitVector->Elements.empty()) {
329         AtEnd = true;
330         return;
331       }
332       Iter = BitVector->Elements.begin();
333       BitNumber = Iter->index() * ElementSize;
334       unsigned BitPos = Iter->find_first();
335       BitNumber += BitPos;
336       WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
337       Bits = Iter->word(WordNumber);
338       Bits >>= BitPos % BITWORD_SIZE;
339     }
340 
341     // Move our iterator to the next non-zero bit.
AdvanceToNextNonZero()342     void AdvanceToNextNonZero() {
343       if (AtEnd)
344         return;
345 
346       while (Bits && !(Bits & 1)) {
347         Bits >>= 1;
348         BitNumber += 1;
349       }
350 
351       // See if we ran out of Bits in this word.
352       if (!Bits) {
353         int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize);
354         // If we ran out of set bits in this element, move to next element.
355         if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) {
356           ++Iter;
357           WordNumber = 0;
358 
359           // We may run out of elements in the bitmap.
360           if (Iter == BitVector->Elements.end()) {
361             AtEnd = true;
362             return;
363           }
364           // Set up for next non-zero word in bitmap.
365           BitNumber = Iter->index() * ElementSize;
366           NextSetBitNumber = Iter->find_first();
367           BitNumber += NextSetBitNumber;
368           WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
369           Bits = Iter->word(WordNumber);
370           Bits >>= NextSetBitNumber % BITWORD_SIZE;
371         } else {
372           WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE;
373           Bits = Iter->word(WordNumber);
374           Bits >>= NextSetBitNumber % BITWORD_SIZE;
375           BitNumber = Iter->index() * ElementSize;
376           BitNumber += NextSetBitNumber;
377         }
378       }
379     }
380 
381    public:
382     SparseBitVectorIterator() = default;
383 
384     SparseBitVectorIterator(
385         const SparseBitVector<ElementSize>* RHS,
386         bool end = false)
AtEnd(end)387         : AtEnd(end),
388           BitVector(RHS),
389           Iter(BitVector->Elements.begin()),
390           WordNumber(~0) {
391       AdvanceToFirstNonZero();
392     }
393 
394     // Preincrement.
395     inline SparseBitVectorIterator& operator++() {
396       ++BitNumber;
397       Bits >>= 1;
398       AdvanceToNextNonZero();
399       return *this;
400     }
401 
402     // Postincrement.
403     inline SparseBitVectorIterator operator++(int) {
404       SparseBitVectorIterator tmp = *this;
405       ++*this;
406       return tmp;
407     }
408 
409     // Return the current set bit number.
410     unsigned operator*() const {
411       return BitNumber;
412     }
413 
414     bool operator==(const SparseBitVectorIterator& RHS) const {
415       // If they are both at the end, ignore the rest of the fields.
416       if (AtEnd && RHS.AtEnd)
417         return true;
418       // Otherwise they are the same if they have the same bit number and
419       // bitmap.
420       return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber;
421     }
422 
423     bool operator!=(const SparseBitVectorIterator& RHS) const {
424       return !(*this == RHS);
425     }
426   };
427 
428  public:
429   using iterator = SparseBitVectorIterator;
430 
SparseBitVector()431   SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {}
432 
SparseBitVector(const SparseBitVector & RHS)433   SparseBitVector(const SparseBitVector& RHS)
434       : Elements(RHS.Elements), CurrElementIter(Elements.begin()) {}
SparseBitVector(SparseBitVector && RHS)435   SparseBitVector(SparseBitVector&& RHS) noexcept
436       : Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {}
437 
438   // Clear.
clear()439   void clear() {
440     Elements.clear();
441   }
442 
443   // Assignment
444   SparseBitVector& operator=(const SparseBitVector& RHS) {
445     if (this == &RHS)
446       return *this;
447 
448     Elements = RHS.Elements;
449     CurrElementIter = Elements.begin();
450     return *this;
451   }
452   SparseBitVector& operator=(SparseBitVector&& RHS) noexcept {
453     Elements = std::move(RHS.Elements);
454     CurrElementIter = Elements.begin();
455     return *this;
456   }
457 
458   // Test, Reset, and Set a bit in the bitmap.
test(unsigned Idx)459   bool test(unsigned Idx) const {
460     if (Elements.empty())
461       return false;
462 
463     unsigned ElementIndex = Idx / ElementSize;
464     ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex);
465 
466     // If we can't find an element that is supposed to contain this bit, there
467     // is nothing more to do.
468     if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
469       return false;
470     return ElementIter->test(Idx % ElementSize);
471   }
472 
reset(unsigned Idx)473   void reset(unsigned Idx) {
474     if (Elements.empty())
475       return;
476 
477     unsigned ElementIndex = Idx / ElementSize;
478     ElementListIter ElementIter = FindLowerBound(ElementIndex);
479 
480     // If we can't find an element that is supposed to contain this bit, there
481     // is nothing more to do.
482     if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
483       return;
484     ElementIter->reset(Idx % ElementSize);
485 
486     // When the element is zeroed out, delete it.
487     if (ElementIter->empty()) {
488       ++CurrElementIter;
489       Elements.erase(ElementIter);
490     }
491   }
492 
set(unsigned Idx)493   void set(unsigned Idx) {
494     unsigned ElementIndex = Idx / ElementSize;
495     ElementListIter ElementIter;
496     if (Elements.empty()) {
497       ElementIter = Elements.emplace(Elements.end(), ElementIndex);
498     } else {
499       ElementIter = FindLowerBound(ElementIndex);
500 
501       if (ElementIter == Elements.end() ||
502           ElementIter->index() != ElementIndex) {
503         // We may have hit the beginning of our SparseBitVector, in which case,
504         // we may need to insert right after this element, which requires moving
505         // the current iterator forward one, because insert does insert before.
506         if (ElementIter != Elements.end() &&
507             ElementIter->index() < ElementIndex)
508           ++ElementIter;
509         ElementIter = Elements.emplace(ElementIter, ElementIndex);
510       }
511     }
512     CurrElementIter = ElementIter;
513 
514     ElementIter->set(Idx % ElementSize);
515   }
516 
test_and_set(unsigned Idx)517   bool test_and_set(unsigned Idx) {
518     bool old = test(Idx);
519     if (!old) {
520       set(Idx);
521       return true;
522     }
523     return false;
524   }
525 
526   bool operator!=(const SparseBitVector& RHS) const {
527     return !(*this == RHS);
528   }
529 
530   bool operator==(const SparseBitVector& RHS) const {
531     ElementListConstIter Iter1 = Elements.begin();
532     ElementListConstIter Iter2 = RHS.Elements.begin();
533 
534     for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end();
535          ++Iter1, ++Iter2) {
536       if (*Iter1 != *Iter2)
537         return false;
538     }
539     return Iter1 == Elements.end() && Iter2 == RHS.Elements.end();
540   }
541 
542   // Union our bitmap with the RHS and return true if we changed.
543   bool operator|=(const SparseBitVector& RHS) {
544     if (this == &RHS)
545       return false;
546 
547     if (empty()) {
548       *this = RHS;
549       return true;
550     }
551 
552     bool changed = false;
553     ElementListIter Iter1 = Elements.begin();
554     ElementListConstIter Iter2 = RHS.Elements.begin();
555 
556     // If RHS is empty, we are done
557     if (RHS.Elements.empty())
558       return false;
559 
560     while (Iter2 != RHS.Elements.end()) {
561       if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) {
562         Elements.insert(Iter1, *Iter2);
563         ++Iter2;
564         changed = true;
565       } else if (Iter1->index() == Iter2->index()) {
566         changed |= Iter1->unionWith(*Iter2);
567         ++Iter1;
568         ++Iter2;
569       } else {
570         ++Iter1;
571       }
572     }
573     CurrElementIter = Elements.begin();
574     return changed;
575   }
576 
577   // Intersect our bitmap with the RHS and return true if ours changed.
578   bool operator-=(const SparseBitVector& RHS) {
579     return intersectWithComplement(RHS);
580   }
581 
582   // Intersect our bitmap with the RHS and return true if ours changed.
583   bool operator&=(const SparseBitVector& RHS) {
584     if (this == &RHS)
585       return false;
586 
587     bool changed = false;
588     ElementListIter Iter1 = Elements.begin();
589     ElementListConstIter Iter2 = RHS.Elements.begin();
590 
591     // Check if both bitmaps are empty.
592     if (Elements.empty() && RHS.Elements.empty())
593       return false;
594 
595     // Loop through, intersecting as we go, erasing elements when necessary.
596     while (Iter2 != RHS.Elements.end()) {
597       if (Iter1 == Elements.end()) {
598         CurrElementIter = Elements.begin();
599         return changed;
600       }
601 
602       if (Iter1->index() > Iter2->index()) {
603         ++Iter2;
604       } else if (Iter1->index() == Iter2->index()) {
605         bool BecameZero = false;
606         changed |= Iter1->intersectWith(*Iter2, BecameZero);
607         if (BecameZero) {
608           ElementListIter IterTmp = Iter1;
609           ++Iter1;
610           Elements.erase(IterTmp);
611         } else {
612           ++Iter1;
613         }
614         ++Iter2;
615       } else {
616         ElementListIter IterTmp = Iter1;
617         ++Iter1;
618         Elements.erase(IterTmp);
619         changed = true;
620       }
621     }
622     if (Iter1 != Elements.end()) {
623       Elements.erase(Iter1, Elements.end());
624       changed = true;
625     }
626     CurrElementIter = Elements.begin();
627     return changed;
628   }
629 
630   // Intersect our bitmap with the complement of the RHS and return true
631   // if ours changed.
intersectWithComplement(const SparseBitVector & RHS)632   bool intersectWithComplement(const SparseBitVector& RHS) {
633     if (this == &RHS) {
634       if (!empty()) {
635         clear();
636         return true;
637       }
638       return false;
639     }
640 
641     bool changed = false;
642     ElementListIter Iter1 = Elements.begin();
643     ElementListConstIter Iter2 = RHS.Elements.begin();
644 
645     // If either our bitmap or RHS is empty, we are done
646     if (Elements.empty() || RHS.Elements.empty())
647       return false;
648 
649     // Loop through, intersecting as we go, erasing elements when necessary.
650     while (Iter2 != RHS.Elements.end()) {
651       if (Iter1 == Elements.end()) {
652         CurrElementIter = Elements.begin();
653         return changed;
654       }
655 
656       if (Iter1->index() > Iter2->index()) {
657         ++Iter2;
658       } else if (Iter1->index() == Iter2->index()) {
659         bool BecameZero = false;
660         changed |= Iter1->intersectWithComplement(*Iter2, BecameZero);
661         if (BecameZero) {
662           ElementListIter IterTmp = Iter1;
663           ++Iter1;
664           Elements.erase(IterTmp);
665         } else {
666           ++Iter1;
667         }
668         ++Iter2;
669       } else {
670         ++Iter1;
671       }
672     }
673     CurrElementIter = Elements.begin();
674     return changed;
675   }
676 
intersectWithComplement(const SparseBitVector<ElementSize> * RHS)677   bool intersectWithComplement(const SparseBitVector<ElementSize>* RHS) const {
678     return intersectWithComplement(*RHS);
679   }
680 
681   //  Three argument version of intersectWithComplement.
682   //  Result of RHS1 & ~RHS2 is stored into this bitmap.
intersectWithComplement(const SparseBitVector<ElementSize> & RHS1,const SparseBitVector<ElementSize> & RHS2)683   void intersectWithComplement(
684       const SparseBitVector<ElementSize>& RHS1,
685       const SparseBitVector<ElementSize>& RHS2) {
686     if (this == &RHS1) {
687       intersectWithComplement(RHS2);
688       return;
689     } else if (this == &RHS2) {
690       SparseBitVector RHS2Copy(RHS2);
691       intersectWithComplement(RHS1, RHS2Copy);
692       return;
693     }
694 
695     Elements.clear();
696     CurrElementIter = Elements.begin();
697     ElementListConstIter Iter1 = RHS1.Elements.begin();
698     ElementListConstIter Iter2 = RHS2.Elements.begin();
699 
700     // If RHS1 is empty, we are done
701     // If RHS2 is empty, we still have to copy RHS1
702     if (RHS1.Elements.empty())
703       return;
704 
705     // Loop through, intersecting as we go, erasing elements when necessary.
706     while (Iter2 != RHS2.Elements.end()) {
707       if (Iter1 == RHS1.Elements.end())
708         return;
709 
710       if (Iter1->index() > Iter2->index()) {
711         ++Iter2;
712       } else if (Iter1->index() == Iter2->index()) {
713         bool BecameZero = false;
714         Elements.emplace_back(Iter1->index());
715         Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero);
716         if (BecameZero)
717           Elements.pop_back();
718         ++Iter1;
719         ++Iter2;
720       } else {
721         Elements.push_back(*Iter1++);
722       }
723     }
724 
725     // copy the remaining elements
726     std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements));
727   }
728 
intersectWithComplement(const SparseBitVector<ElementSize> * RHS1,const SparseBitVector<ElementSize> * RHS2)729   void intersectWithComplement(
730       const SparseBitVector<ElementSize>* RHS1,
731       const SparseBitVector<ElementSize>* RHS2) {
732     intersectWithComplement(*RHS1, *RHS2);
733   }
734 
intersects(const SparseBitVector<ElementSize> * RHS)735   bool intersects(const SparseBitVector<ElementSize>* RHS) const {
736     return intersects(*RHS);
737   }
738 
739   // Return true if we share any bits in common with RHS
intersects(const SparseBitVector<ElementSize> & RHS)740   bool intersects(const SparseBitVector<ElementSize>& RHS) const {
741     ElementListConstIter Iter1 = Elements.begin();
742     ElementListConstIter Iter2 = RHS.Elements.begin();
743 
744     // Check if both bitmaps are empty.
745     if (Elements.empty() && RHS.Elements.empty())
746       return false;
747 
748     // Loop through, intersecting stopping when we hit bits in common.
749     while (Iter2 != RHS.Elements.end()) {
750       if (Iter1 == Elements.end())
751         return false;
752 
753       if (Iter1->index() > Iter2->index()) {
754         ++Iter2;
755       } else if (Iter1->index() == Iter2->index()) {
756         if (Iter1->intersects(*Iter2))
757           return true;
758         ++Iter1;
759         ++Iter2;
760       } else {
761         ++Iter1;
762       }
763     }
764     return false;
765   }
766 
767   // Return true iff all bits set in this SparseBitVector are
768   // also set in RHS.
contains(const SparseBitVector<ElementSize> & RHS)769   bool contains(const SparseBitVector<ElementSize>& RHS) const {
770     SparseBitVector<ElementSize> Result(*this);
771     Result &= RHS;
772     return (Result == RHS);
773   }
774 
775   // Return the first set bit in the bitmap.  Return -1 if no bits are set.
find_first()776   int find_first() const {
777     if (Elements.empty())
778       return -1;
779     const SparseBitVectorElement<ElementSize>& First = *(Elements.begin());
780     return (First.index() * ElementSize) + First.find_first();
781   }
782 
783   // Return the last set bit in the bitmap.  Return -1 if no bits are set.
find_last()784   int find_last() const {
785     if (Elements.empty())
786       return -1;
787     const SparseBitVectorElement<ElementSize>& Last = *(Elements.rbegin());
788     return (Last.index() * ElementSize) + Last.find_last();
789   }
790 
791   // Return true if the SparseBitVector is empty
empty()792   bool empty() const {
793     return Elements.empty();
794   }
795 
count()796   unsigned count() const {
797     unsigned BitCount = 0;
798     for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end();
799          ++Iter)
800       BitCount += Iter->count();
801 
802     return BitCount;
803   }
804 
begin()805   iterator begin() const {
806     return iterator(this);
807   }
808 
end()809   iterator end() const {
810     return iterator(this, true);
811   }
812 };
813 
814 // Convenience functions to allow Or and And without dereferencing in the user
815 // code.
816 
817 template <unsigned ElementSize>
818 inline bool operator|=(
819     SparseBitVector<ElementSize>& LHS,
820     const SparseBitVector<ElementSize>* RHS) {
821   return LHS |= *RHS;
822 }
823 
824 template <unsigned ElementSize>
825 inline bool operator|=(
826     SparseBitVector<ElementSize>* LHS,
827     const SparseBitVector<ElementSize>& RHS) {
828   return LHS->operator|=(RHS);
829 }
830 
831 template <unsigned ElementSize>
832 inline bool operator&=(
833     SparseBitVector<ElementSize>* LHS,
834     const SparseBitVector<ElementSize>& RHS) {
835   return LHS->operator&=(RHS);
836 }
837 
838 template <unsigned ElementSize>
839 inline bool operator&=(
840     SparseBitVector<ElementSize>& LHS,
841     const SparseBitVector<ElementSize>* RHS) {
842   return LHS &= *RHS;
843 }
844 
845 // Convenience functions for infix union, intersection, difference operators.
846 
847 template <unsigned ElementSize>
848 inline SparseBitVector<ElementSize> operator|(
849     const SparseBitVector<ElementSize>& LHS,
850     const SparseBitVector<ElementSize>& RHS) {
851   SparseBitVector<ElementSize> Result(LHS);
852   Result |= RHS;
853   return Result;
854 }
855 
856 template <unsigned ElementSize>
857 inline SparseBitVector<ElementSize> operator&(
858     const SparseBitVector<ElementSize>& LHS,
859     const SparseBitVector<ElementSize>& RHS) {
860   SparseBitVector<ElementSize> Result(LHS);
861   Result &= RHS;
862   return Result;
863 }
864 
865 template <unsigned ElementSize>
866 inline SparseBitVector<ElementSize> operator-(
867     const SparseBitVector<ElementSize>& LHS,
868     const SparseBitVector<ElementSize>& RHS) {
869   SparseBitVector<ElementSize> Result;
870   Result.intersectWithComplement(LHS, RHS);
871   return Result;
872 }
873 
874 template <unsigned ElementSize>
875 std::ostream& operator<<(
876     std::ostream& stream,
877     const SparseBitVector<ElementSize>& vec) {
878   bool first = true;
879   stream << "{";
880   for (auto el : vec) {
881     if (first) {
882       first = false;
883     } else {
884       stream << ", ";
885     }
886     stream << el;
887   }
888   stream << "}";
889   return stream;
890 }
891 
892 } // end namespace c10
893