//===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SparseBitVector class. See the doxygen comment for
// SparseBitVector for more details on the algorithm used.
//
//===----------------------------------------------------------------------===//

#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/llvmMathExtras.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <cstddef>
#include <iterator>
#include <list>
#include <ostream>
#include <stdexcept>
#include <utility>

namespace c10 {

/// SparseBitVector is an implementation of a bitvector that is sparse by only
/// storing the elements that have non-zero bits set. In order to make this
/// fast for the most common cases, SparseBitVector is implemented as a linked
/// list of SparseBitVectorElements. We maintain a pointer to the last
/// SparseBitVectorElement accessed (in the form of a list iterator), in order
/// to make multiple in-order test/set constant time after the first one is
/// executed. Note that using vectors to store SparseBitVectorElement's does
/// not work out very well because it causes insertion in the middle to take
/// enormous amounts of time with a large amount of bits. Other structures that
/// have better worst cases for insertion in the middle (various balanced trees,
/// etc) do not perform as well in practice as a linked list with this iterator
/// kept up to date. They are also significantly more memory intensive.
38 39 template <unsigned ElementSize = 128> 40 struct SparseBitVectorElement { 41 public: 42 using BitWord = unsigned long; 43 using size_type = unsigned; 44 enum { 45 BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT, 46 BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE, 47 BITS_PER_ELEMENT = ElementSize 48 }; 49 50 private: 51 // Index of Element in terms of where first bit starts. 52 unsigned ElementIndex; 53 std::array<BitWord, BITWORDS_PER_ELEMENT> Bits{}; 54 SparseBitVectorElementSparseBitVectorElement55 SparseBitVectorElement() : ElementIndex(~0U) {} 56 57 public: SparseBitVectorElementSparseBitVectorElement58 explicit SparseBitVectorElement(unsigned Idx) : ElementIndex(Idx) {} 59 60 // Comparison. 61 bool operator==(const SparseBitVectorElement& RHS) const { 62 if (ElementIndex != RHS.ElementIndex) 63 return false; 64 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) 65 if (Bits[i] != RHS.Bits[i]) 66 return false; 67 return true; 68 } 69 70 bool operator!=(const SparseBitVectorElement& RHS) const { 71 return !(*this == RHS); 72 } 73 74 // Return the bits that make up word Idx in our element. 
wordSparseBitVectorElement75 BitWord word(unsigned Idx) const { 76 assert(Idx < BITWORDS_PER_ELEMENT); 77 return Bits[Idx]; 78 } 79 indexSparseBitVectorElement80 unsigned index() const { 81 return ElementIndex; 82 } 83 emptySparseBitVectorElement84 bool empty() const { 85 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) 86 if (Bits[i]) 87 return false; 88 return true; 89 } 90 setSparseBitVectorElement91 void set(unsigned Idx) { 92 Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE); 93 } 94 test_and_setSparseBitVectorElement95 bool test_and_set(unsigned Idx) { 96 bool old = test(Idx); 97 if (!old) { 98 set(Idx); 99 return true; 100 } 101 return false; 102 } 103 resetSparseBitVectorElement104 void reset(unsigned Idx) { 105 Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE)); 106 } 107 testSparseBitVectorElement108 bool test(unsigned Idx) const { 109 return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE)); 110 } 111 countSparseBitVectorElement112 size_type count() const { 113 unsigned NumBits = 0; 114 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) 115 NumBits += llvm::countPopulation(Bits[i]); 116 return NumBits; 117 } 118 119 /// find_first - Returns the index of the first set bit. find_firstSparseBitVectorElement120 int find_first() const { 121 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) 122 if (Bits[i] != 0) 123 return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]); 124 throw std::runtime_error("Illegal empty element"); 125 } 126 127 /// find_last - Returns the index of the last set bit. find_lastSparseBitVectorElement128 int find_last() const { 129 for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) { 130 unsigned Idx = BITWORDS_PER_ELEMENT - I - 1; 131 if (Bits[Idx] != 0) 132 return Idx * BITWORD_SIZE + BITWORD_SIZE - 133 llvm::countLeadingZeros(Bits[Idx]); 134 } 135 throw std::runtime_error("Illegal empty element"); 136 } 137 138 /// find_next - Returns the index of the next set bit starting from the 139 /// "Curr" bit. 
Returns -1 if the next set bit is not found. find_nextSparseBitVectorElement140 int find_next(unsigned Curr) const { 141 if (Curr >= BITS_PER_ELEMENT) 142 return -1; 143 144 unsigned WordPos = Curr / BITWORD_SIZE; 145 unsigned BitPos = Curr % BITWORD_SIZE; 146 BitWord Copy = Bits[WordPos]; 147 assert( 148 WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element"); 149 150 // Mask off previous bits. 151 Copy &= ~0UL << BitPos; 152 153 if (Copy != 0) 154 return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy); 155 156 // Check subsequent words. 157 for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i) 158 if (Bits[i] != 0) 159 return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]); 160 return -1; 161 } 162 163 // Union this element with RHS and return true if this one changed. unionWithSparseBitVectorElement164 bool unionWith(const SparseBitVectorElement& RHS) { 165 bool changed = false; 166 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { 167 BitWord old = changed ? 0 : Bits[i]; 168 169 Bits[i] |= RHS.Bits[i]; 170 if (!changed && old != Bits[i]) 171 changed = true; 172 } 173 return changed; 174 } 175 176 // Return true if we have any bits in common with RHS intersectsSparseBitVectorElement177 bool intersects(const SparseBitVectorElement& RHS) const { 178 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { 179 if (RHS.Bits[i] & Bits[i]) 180 return true; 181 } 182 return false; 183 } 184 185 // Intersect this Element with RHS and return true if this one changed. 186 // BecameZero is set to true if this element became all-zero bits. intersectWithSparseBitVectorElement187 bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) { 188 bool changed = false; 189 bool allzero = true; 190 191 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { 192 BitWord old = changed ? 
0 : Bits[i]; 193 194 Bits[i] &= RHS.Bits[i]; 195 if (Bits[i] != 0) 196 allzero = false; 197 198 if (!changed && old != Bits[i]) 199 changed = true; 200 } 201 BecameZero = allzero; 202 return changed; 203 } 204 205 // Intersect this Element with the complement of RHS and return true if this 206 // one changed. BecameZero is set to true if this element became all-zero 207 // bits. intersectWithComplementSparseBitVectorElement208 bool intersectWithComplement( 209 const SparseBitVectorElement& RHS, 210 bool& BecameZero) { 211 bool changed = false; 212 bool allzero = true; 213 214 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { 215 BitWord old = changed ? 0 : Bits[i]; 216 217 Bits[i] &= ~RHS.Bits[i]; 218 if (Bits[i] != 0) 219 allzero = false; 220 221 if (!changed && old != Bits[i]) 222 changed = true; 223 } 224 BecameZero = allzero; 225 return changed; 226 } 227 228 // Three argument version of intersectWithComplement that intersects 229 // RHS1 & ~RHS2 into this element intersectWithComplementSparseBitVectorElement230 void intersectWithComplement( 231 const SparseBitVectorElement& RHS1, 232 const SparseBitVectorElement& RHS2, 233 bool& BecameZero) { 234 bool allzero = true; 235 236 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { 237 Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i]; 238 if (Bits[i] != 0) 239 allzero = false; 240 } 241 BecameZero = allzero; 242 } 243 }; 244 245 template <unsigned ElementSize = 128> 246 class SparseBitVector { 247 using ElementList = std::list<SparseBitVectorElement<ElementSize>>; 248 using ElementListIter = typename ElementList::iterator; 249 using ElementListConstIter = typename ElementList::const_iterator; 250 enum { BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE }; 251 252 ElementList Elements; 253 // Pointer to our current Element. 
This has no visible effect on the external 254 // state of a SparseBitVector, it's just used to improve performance in the 255 // common case of testing/modifying bits with similar indices. 256 mutable ElementListIter CurrElementIter; 257 258 // This is like std::lower_bound, except we do linear searching from the 259 // current position. FindLowerBoundImpl(unsigned ElementIndex)260 ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const { 261 // We cache a non-const iterator so we're forced to resort to const_cast to 262 // get the begin/end in the case where 'this' is const. To avoid duplication 263 // of code with the only difference being whether the const cast is present 264 // 'this' is always const in this particular function and we sort out the 265 // difference in FindLowerBound and FindLowerBoundConst. 266 ElementListIter Begin = 267 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) 268 const_cast<SparseBitVector<ElementSize>*>(this)->Elements.begin(); 269 ElementListIter End = 270 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) 271 const_cast<SparseBitVector<ElementSize>*>(this)->Elements.end(); 272 273 if (Elements.empty()) { 274 CurrElementIter = Begin; 275 return CurrElementIter; 276 } 277 278 // Make sure our current iterator is valid. 279 if (CurrElementIter == End) 280 --CurrElementIter; 281 282 // Search from our current iterator, either backwards or forwards, 283 // depending on what element we are looking for. 
284 ElementListIter ElementIter = CurrElementIter; 285 if (CurrElementIter->index() == ElementIndex) { 286 return ElementIter; 287 } else if (CurrElementIter->index() > ElementIndex) { 288 while (ElementIter != Begin && ElementIter->index() > ElementIndex) 289 --ElementIter; 290 } else { 291 while (ElementIter != End && ElementIter->index() < ElementIndex) 292 ++ElementIter; 293 } 294 CurrElementIter = ElementIter; 295 return ElementIter; 296 } FindLowerBoundConst(unsigned ElementIndex)297 ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const { 298 return FindLowerBoundImpl(ElementIndex); 299 } FindLowerBound(unsigned ElementIndex)300 ElementListIter FindLowerBound(unsigned ElementIndex) { 301 return FindLowerBoundImpl(ElementIndex); 302 } 303 304 // Iterator to walk set bits in the bitmap. This iterator is a lot uglier 305 // than it would be, in order to be efficient. 306 class SparseBitVectorIterator { 307 private: 308 bool AtEnd{false}; 309 310 const SparseBitVector<ElementSize>* BitVector = nullptr; 311 312 // Current element inside of bitmap. 313 ElementListConstIter Iter; 314 315 // Current bit number inside of our bitmap. 316 unsigned BitNumber{0}; 317 318 // Current word number inside of our element. 319 unsigned WordNumber{0}; 320 321 // Current bits from the element. 322 typename SparseBitVectorElement<ElementSize>::BitWord Bits{0}; 323 324 // Move our iterator to the first non-zero bit in the bitmap. AdvanceToFirstNonZero()325 void AdvanceToFirstNonZero() { 326 if (AtEnd) 327 return; 328 if (BitVector->Elements.empty()) { 329 AtEnd = true; 330 return; 331 } 332 Iter = BitVector->Elements.begin(); 333 BitNumber = Iter->index() * ElementSize; 334 unsigned BitPos = Iter->find_first(); 335 BitNumber += BitPos; 336 WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE; 337 Bits = Iter->word(WordNumber); 338 Bits >>= BitPos % BITWORD_SIZE; 339 } 340 341 // Move our iterator to the next non-zero bit. 
AdvanceToNextNonZero()342 void AdvanceToNextNonZero() { 343 if (AtEnd) 344 return; 345 346 while (Bits && !(Bits & 1)) { 347 Bits >>= 1; 348 BitNumber += 1; 349 } 350 351 // See if we ran out of Bits in this word. 352 if (!Bits) { 353 int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize); 354 // If we ran out of set bits in this element, move to next element. 355 if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) { 356 ++Iter; 357 WordNumber = 0; 358 359 // We may run out of elements in the bitmap. 360 if (Iter == BitVector->Elements.end()) { 361 AtEnd = true; 362 return; 363 } 364 // Set up for next non-zero word in bitmap. 365 BitNumber = Iter->index() * ElementSize; 366 NextSetBitNumber = Iter->find_first(); 367 BitNumber += NextSetBitNumber; 368 WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE; 369 Bits = Iter->word(WordNumber); 370 Bits >>= NextSetBitNumber % BITWORD_SIZE; 371 } else { 372 WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE; 373 Bits = Iter->word(WordNumber); 374 Bits >>= NextSetBitNumber % BITWORD_SIZE; 375 BitNumber = Iter->index() * ElementSize; 376 BitNumber += NextSetBitNumber; 377 } 378 } 379 } 380 381 public: 382 SparseBitVectorIterator() = default; 383 384 SparseBitVectorIterator( 385 const SparseBitVector<ElementSize>* RHS, 386 bool end = false) AtEnd(end)387 : AtEnd(end), 388 BitVector(RHS), 389 Iter(BitVector->Elements.begin()), 390 WordNumber(~0) { 391 AdvanceToFirstNonZero(); 392 } 393 394 // Preincrement. 395 inline SparseBitVectorIterator& operator++() { 396 ++BitNumber; 397 Bits >>= 1; 398 AdvanceToNextNonZero(); 399 return *this; 400 } 401 402 // Postincrement. 403 inline SparseBitVectorIterator operator++(int) { 404 SparseBitVectorIterator tmp = *this; 405 ++*this; 406 return tmp; 407 } 408 409 // Return the current set bit number. 
410 unsigned operator*() const { 411 return BitNumber; 412 } 413 414 bool operator==(const SparseBitVectorIterator& RHS) const { 415 // If they are both at the end, ignore the rest of the fields. 416 if (AtEnd && RHS.AtEnd) 417 return true; 418 // Otherwise they are the same if they have the same bit number and 419 // bitmap. 420 return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber; 421 } 422 423 bool operator!=(const SparseBitVectorIterator& RHS) const { 424 return !(*this == RHS); 425 } 426 }; 427 428 public: 429 using iterator = SparseBitVectorIterator; 430 SparseBitVector()431 SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {} 432 SparseBitVector(const SparseBitVector & RHS)433 SparseBitVector(const SparseBitVector& RHS) 434 : Elements(RHS.Elements), CurrElementIter(Elements.begin()) {} SparseBitVector(SparseBitVector && RHS)435 SparseBitVector(SparseBitVector&& RHS) noexcept 436 : Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {} 437 438 // Clear. clear()439 void clear() { 440 Elements.clear(); 441 } 442 443 // Assignment 444 SparseBitVector& operator=(const SparseBitVector& RHS) { 445 if (this == &RHS) 446 return *this; 447 448 Elements = RHS.Elements; 449 CurrElementIter = Elements.begin(); 450 return *this; 451 } 452 SparseBitVector& operator=(SparseBitVector&& RHS) noexcept { 453 Elements = std::move(RHS.Elements); 454 CurrElementIter = Elements.begin(); 455 return *this; 456 } 457 458 // Test, Reset, and Set a bit in the bitmap. test(unsigned Idx)459 bool test(unsigned Idx) const { 460 if (Elements.empty()) 461 return false; 462 463 unsigned ElementIndex = Idx / ElementSize; 464 ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex); 465 466 // If we can't find an element that is supposed to contain this bit, there 467 // is nothing more to do. 
468 if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex) 469 return false; 470 return ElementIter->test(Idx % ElementSize); 471 } 472 reset(unsigned Idx)473 void reset(unsigned Idx) { 474 if (Elements.empty()) 475 return; 476 477 unsigned ElementIndex = Idx / ElementSize; 478 ElementListIter ElementIter = FindLowerBound(ElementIndex); 479 480 // If we can't find an element that is supposed to contain this bit, there 481 // is nothing more to do. 482 if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex) 483 return; 484 ElementIter->reset(Idx % ElementSize); 485 486 // When the element is zeroed out, delete it. 487 if (ElementIter->empty()) { 488 ++CurrElementIter; 489 Elements.erase(ElementIter); 490 } 491 } 492 set(unsigned Idx)493 void set(unsigned Idx) { 494 unsigned ElementIndex = Idx / ElementSize; 495 ElementListIter ElementIter; 496 if (Elements.empty()) { 497 ElementIter = Elements.emplace(Elements.end(), ElementIndex); 498 } else { 499 ElementIter = FindLowerBound(ElementIndex); 500 501 if (ElementIter == Elements.end() || 502 ElementIter->index() != ElementIndex) { 503 // We may have hit the beginning of our SparseBitVector, in which case, 504 // we may need to insert right after this element, which requires moving 505 // the current iterator forward one, because insert does insert before. 
506 if (ElementIter != Elements.end() && 507 ElementIter->index() < ElementIndex) 508 ++ElementIter; 509 ElementIter = Elements.emplace(ElementIter, ElementIndex); 510 } 511 } 512 CurrElementIter = ElementIter; 513 514 ElementIter->set(Idx % ElementSize); 515 } 516 test_and_set(unsigned Idx)517 bool test_and_set(unsigned Idx) { 518 bool old = test(Idx); 519 if (!old) { 520 set(Idx); 521 return true; 522 } 523 return false; 524 } 525 526 bool operator!=(const SparseBitVector& RHS) const { 527 return !(*this == RHS); 528 } 529 530 bool operator==(const SparseBitVector& RHS) const { 531 ElementListConstIter Iter1 = Elements.begin(); 532 ElementListConstIter Iter2 = RHS.Elements.begin(); 533 534 for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end(); 535 ++Iter1, ++Iter2) { 536 if (*Iter1 != *Iter2) 537 return false; 538 } 539 return Iter1 == Elements.end() && Iter2 == RHS.Elements.end(); 540 } 541 542 // Union our bitmap with the RHS and return true if we changed. 543 bool operator|=(const SparseBitVector& RHS) { 544 if (this == &RHS) 545 return false; 546 547 if (empty()) { 548 *this = RHS; 549 return true; 550 } 551 552 bool changed = false; 553 ElementListIter Iter1 = Elements.begin(); 554 ElementListConstIter Iter2 = RHS.Elements.begin(); 555 556 // If RHS is empty, we are done 557 if (RHS.Elements.empty()) 558 return false; 559 560 while (Iter2 != RHS.Elements.end()) { 561 if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) { 562 Elements.insert(Iter1, *Iter2); 563 ++Iter2; 564 changed = true; 565 } else if (Iter1->index() == Iter2->index()) { 566 changed |= Iter1->unionWith(*Iter2); 567 ++Iter1; 568 ++Iter2; 569 } else { 570 ++Iter1; 571 } 572 } 573 CurrElementIter = Elements.begin(); 574 return changed; 575 } 576 577 // Intersect our bitmap with the RHS and return true if ours changed. 
578 bool operator-=(const SparseBitVector& RHS) { 579 return intersectWithComplement(RHS); 580 } 581 582 // Intersect our bitmap with the RHS and return true if ours changed. 583 bool operator&=(const SparseBitVector& RHS) { 584 if (this == &RHS) 585 return false; 586 587 bool changed = false; 588 ElementListIter Iter1 = Elements.begin(); 589 ElementListConstIter Iter2 = RHS.Elements.begin(); 590 591 // Check if both bitmaps are empty. 592 if (Elements.empty() && RHS.Elements.empty()) 593 return false; 594 595 // Loop through, intersecting as we go, erasing elements when necessary. 596 while (Iter2 != RHS.Elements.end()) { 597 if (Iter1 == Elements.end()) { 598 CurrElementIter = Elements.begin(); 599 return changed; 600 } 601 602 if (Iter1->index() > Iter2->index()) { 603 ++Iter2; 604 } else if (Iter1->index() == Iter2->index()) { 605 bool BecameZero = false; 606 changed |= Iter1->intersectWith(*Iter2, BecameZero); 607 if (BecameZero) { 608 ElementListIter IterTmp = Iter1; 609 ++Iter1; 610 Elements.erase(IterTmp); 611 } else { 612 ++Iter1; 613 } 614 ++Iter2; 615 } else { 616 ElementListIter IterTmp = Iter1; 617 ++Iter1; 618 Elements.erase(IterTmp); 619 changed = true; 620 } 621 } 622 if (Iter1 != Elements.end()) { 623 Elements.erase(Iter1, Elements.end()); 624 changed = true; 625 } 626 CurrElementIter = Elements.begin(); 627 return changed; 628 } 629 630 // Intersect our bitmap with the complement of the RHS and return true 631 // if ours changed. 
intersectWithComplement(const SparseBitVector & RHS)632 bool intersectWithComplement(const SparseBitVector& RHS) { 633 if (this == &RHS) { 634 if (!empty()) { 635 clear(); 636 return true; 637 } 638 return false; 639 } 640 641 bool changed = false; 642 ElementListIter Iter1 = Elements.begin(); 643 ElementListConstIter Iter2 = RHS.Elements.begin(); 644 645 // If either our bitmap or RHS is empty, we are done 646 if (Elements.empty() || RHS.Elements.empty()) 647 return false; 648 649 // Loop through, intersecting as we go, erasing elements when necessary. 650 while (Iter2 != RHS.Elements.end()) { 651 if (Iter1 == Elements.end()) { 652 CurrElementIter = Elements.begin(); 653 return changed; 654 } 655 656 if (Iter1->index() > Iter2->index()) { 657 ++Iter2; 658 } else if (Iter1->index() == Iter2->index()) { 659 bool BecameZero = false; 660 changed |= Iter1->intersectWithComplement(*Iter2, BecameZero); 661 if (BecameZero) { 662 ElementListIter IterTmp = Iter1; 663 ++Iter1; 664 Elements.erase(IterTmp); 665 } else { 666 ++Iter1; 667 } 668 ++Iter2; 669 } else { 670 ++Iter1; 671 } 672 } 673 CurrElementIter = Elements.begin(); 674 return changed; 675 } 676 intersectWithComplement(const SparseBitVector<ElementSize> * RHS)677 bool intersectWithComplement(const SparseBitVector<ElementSize>* RHS) const { 678 return intersectWithComplement(*RHS); 679 } 680 681 // Three argument version of intersectWithComplement. 682 // Result of RHS1 & ~RHS2 is stored into this bitmap. 
intersectWithComplement(const SparseBitVector<ElementSize> & RHS1,const SparseBitVector<ElementSize> & RHS2)683 void intersectWithComplement( 684 const SparseBitVector<ElementSize>& RHS1, 685 const SparseBitVector<ElementSize>& RHS2) { 686 if (this == &RHS1) { 687 intersectWithComplement(RHS2); 688 return; 689 } else if (this == &RHS2) { 690 SparseBitVector RHS2Copy(RHS2); 691 intersectWithComplement(RHS1, RHS2Copy); 692 return; 693 } 694 695 Elements.clear(); 696 CurrElementIter = Elements.begin(); 697 ElementListConstIter Iter1 = RHS1.Elements.begin(); 698 ElementListConstIter Iter2 = RHS2.Elements.begin(); 699 700 // If RHS1 is empty, we are done 701 // If RHS2 is empty, we still have to copy RHS1 702 if (RHS1.Elements.empty()) 703 return; 704 705 // Loop through, intersecting as we go, erasing elements when necessary. 706 while (Iter2 != RHS2.Elements.end()) { 707 if (Iter1 == RHS1.Elements.end()) 708 return; 709 710 if (Iter1->index() > Iter2->index()) { 711 ++Iter2; 712 } else if (Iter1->index() == Iter2->index()) { 713 bool BecameZero = false; 714 Elements.emplace_back(Iter1->index()); 715 Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero); 716 if (BecameZero) 717 Elements.pop_back(); 718 ++Iter1; 719 ++Iter2; 720 } else { 721 Elements.push_back(*Iter1++); 722 } 723 } 724 725 // copy the remaining elements 726 std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements)); 727 } 728 intersectWithComplement(const SparseBitVector<ElementSize> * RHS1,const SparseBitVector<ElementSize> * RHS2)729 void intersectWithComplement( 730 const SparseBitVector<ElementSize>* RHS1, 731 const SparseBitVector<ElementSize>* RHS2) { 732 intersectWithComplement(*RHS1, *RHS2); 733 } 734 intersects(const SparseBitVector<ElementSize> * RHS)735 bool intersects(const SparseBitVector<ElementSize>* RHS) const { 736 return intersects(*RHS); 737 } 738 739 // Return true if we share any bits in common with RHS intersects(const SparseBitVector<ElementSize> & RHS)740 
bool intersects(const SparseBitVector<ElementSize>& RHS) const { 741 ElementListConstIter Iter1 = Elements.begin(); 742 ElementListConstIter Iter2 = RHS.Elements.begin(); 743 744 // Check if both bitmaps are empty. 745 if (Elements.empty() && RHS.Elements.empty()) 746 return false; 747 748 // Loop through, intersecting stopping when we hit bits in common. 749 while (Iter2 != RHS.Elements.end()) { 750 if (Iter1 == Elements.end()) 751 return false; 752 753 if (Iter1->index() > Iter2->index()) { 754 ++Iter2; 755 } else if (Iter1->index() == Iter2->index()) { 756 if (Iter1->intersects(*Iter2)) 757 return true; 758 ++Iter1; 759 ++Iter2; 760 } else { 761 ++Iter1; 762 } 763 } 764 return false; 765 } 766 767 // Return true iff all bits set in this SparseBitVector are 768 // also set in RHS. contains(const SparseBitVector<ElementSize> & RHS)769 bool contains(const SparseBitVector<ElementSize>& RHS) const { 770 SparseBitVector<ElementSize> Result(*this); 771 Result &= RHS; 772 return (Result == RHS); 773 } 774 775 // Return the first set bit in the bitmap. Return -1 if no bits are set. find_first()776 int find_first() const { 777 if (Elements.empty()) 778 return -1; 779 const SparseBitVectorElement<ElementSize>& First = *(Elements.begin()); 780 return (First.index() * ElementSize) + First.find_first(); 781 } 782 783 // Return the last set bit in the bitmap. Return -1 if no bits are set. 
find_last()784 int find_last() const { 785 if (Elements.empty()) 786 return -1; 787 const SparseBitVectorElement<ElementSize>& Last = *(Elements.rbegin()); 788 return (Last.index() * ElementSize) + Last.find_last(); 789 } 790 791 // Return true if the SparseBitVector is empty empty()792 bool empty() const { 793 return Elements.empty(); 794 } 795 count()796 unsigned count() const { 797 unsigned BitCount = 0; 798 for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end(); 799 ++Iter) 800 BitCount += Iter->count(); 801 802 return BitCount; 803 } 804 begin()805 iterator begin() const { 806 return iterator(this); 807 } 808 end()809 iterator end() const { 810 return iterator(this, true); 811 } 812 }; 813 814 // Convenience functions to allow Or and And without dereferencing in the user 815 // code. 816 817 template <unsigned ElementSize> 818 inline bool operator|=( 819 SparseBitVector<ElementSize>& LHS, 820 const SparseBitVector<ElementSize>* RHS) { 821 return LHS |= *RHS; 822 } 823 824 template <unsigned ElementSize> 825 inline bool operator|=( 826 SparseBitVector<ElementSize>* LHS, 827 const SparseBitVector<ElementSize>& RHS) { 828 return LHS->operator|=(RHS); 829 } 830 831 template <unsigned ElementSize> 832 inline bool operator&=( 833 SparseBitVector<ElementSize>* LHS, 834 const SparseBitVector<ElementSize>& RHS) { 835 return LHS->operator&=(RHS); 836 } 837 838 template <unsigned ElementSize> 839 inline bool operator&=( 840 SparseBitVector<ElementSize>& LHS, 841 const SparseBitVector<ElementSize>* RHS) { 842 return LHS &= *RHS; 843 } 844 845 // Convenience functions for infix union, intersection, difference operators. 
846 847 template <unsigned ElementSize> 848 inline SparseBitVector<ElementSize> operator|( 849 const SparseBitVector<ElementSize>& LHS, 850 const SparseBitVector<ElementSize>& RHS) { 851 SparseBitVector<ElementSize> Result(LHS); 852 Result |= RHS; 853 return Result; 854 } 855 856 template <unsigned ElementSize> 857 inline SparseBitVector<ElementSize> operator&( 858 const SparseBitVector<ElementSize>& LHS, 859 const SparseBitVector<ElementSize>& RHS) { 860 SparseBitVector<ElementSize> Result(LHS); 861 Result &= RHS; 862 return Result; 863 } 864 865 template <unsigned ElementSize> 866 inline SparseBitVector<ElementSize> operator-( 867 const SparseBitVector<ElementSize>& LHS, 868 const SparseBitVector<ElementSize>& RHS) { 869 SparseBitVector<ElementSize> Result; 870 Result.intersectWithComplement(LHS, RHS); 871 return Result; 872 } 873 874 template <unsigned ElementSize> 875 std::ostream& operator<<( 876 std::ostream& stream, 877 const SparseBitVector<ElementSize>& vec) { 878 bool first = true; 879 stream << "{"; 880 for (auto el : vec) { 881 if (first) { 882 first = false; 883 } else { 884 stream << ", "; 885 } 886 stream << el; 887 } 888 stream << "}"; 889 return stream; 890 } 891 892 } // end namespace c10 893