1 #pragma once 2 3 #include <torch/csrc/utils/variadic.h> 4 #include <torch/types.h> 5 6 #include <c10/util/Exception.h> 7 8 #include <functional> 9 #include <iterator> 10 #include <memory> 11 #include <type_traits> 12 #include <utility> 13 14 namespace torch { 15 namespace data { 16 namespace detail { 17 // For increased safety and more separated logic, this implementation of 18 // `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A 19 // `ValidIterator` yields new batches until the `DataLoader` is exhausted. While 20 // the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are 21 // the same object. When the `ValidIterator` becomes exhausted, it compares 22 // equal to the `SentinelIterator`, but not before. Half the code here is to 23 // implement double dispatch for the comparison. Got damnit, C++. 24 25 template <typename Batch> 26 struct ValidIterator; 27 28 template <typename Batch> 29 struct SentinelIterator; 30 31 /// Base class for the `ValidIterator` and `SentinelIterator` 32 template <typename Batch> 33 struct IteratorImpl { 34 virtual ~IteratorImpl() = default; 35 virtual void next() = 0; 36 virtual Batch& get() = 0; 37 virtual bool operator==(const IteratorImpl& other) const = 0; 38 virtual bool operator==(const ValidIterator<Batch>& other) const = 0; 39 virtual bool operator==(const SentinelIterator<Batch>& other) const = 0; 40 }; 41 42 template <typename Batch> 43 struct ValidIterator : public IteratorImpl<Batch> { 44 using BatchProducer = std::function<std::optional<Batch>()>; 45 ValidIteratorValidIterator46 explicit ValidIterator(BatchProducer next_batch) 47 : next_batch_(std::move(next_batch)) {} 48 49 /// Fetches the next batch. nextValidIterator50 void next() override { 51 // If we didn't get the very first batch yet, get it now. 52 lazy_initialize(); 53 TORCH_CHECK( 54 batch_.has_value(), "Attempted to increment iterator past the end"); 55 // Increment to the next batch. 56 batch_ = next_batch_(); 57 } 58 59 /// Returns the current batch. The precondition for this operation to not 60 /// throw an exception is that it has been compared to the `SentinelIterator` 61 /// and did not compare equal. getValidIterator62 Batch& get() override { 63 // If we didn't get the very first batch yet, get it now. 64 lazy_initialize(); 65 TORCH_CHECK( 66 batch_.has_value(), 67 "Attempted to dereference iterator that was past the end"); 68 return batch_.value(); 69 } 70 71 /// Does double dispatch. 72 bool operator==(const IteratorImpl<Batch>& other) const override { 73 return other == *this; 74 } 75 76 /// A `ValidIterator` is equal to the `SentinelIterator` iff. the 77 /// `ValidIterator` has reached the end of the dataloader. 78 bool operator==(const SentinelIterator<Batch>& /* unused */) const override { 79 lazy_initialize(); 80 return !batch_; 81 } 82 83 /// Returns true if the memory address of `other` equals that of `this`. 84 bool operator==(const ValidIterator<Batch>& other) const override { 85 return &other == this; 86 } 87 88 /// Gets the very first batch if it has not yet been fetched. lazy_initializeValidIterator89 void lazy_initialize() const { 90 if (!initialized_) { 91 batch_ = next_batch_(); 92 initialized_ = true; 93 } 94 } 95 96 BatchProducer next_batch_; 97 mutable std::optional<Batch> batch_; 98 mutable bool initialized_ = false; 99 }; 100 101 template <typename Batch> 102 struct SentinelIterator : public IteratorImpl<Batch> { nextSentinelIterator103 void next() override { 104 AT_ERROR( 105 "Incrementing the DataLoader's past-the-end iterator is not allowed"); 106 } 107 getSentinelIterator108 Batch& get() override { 109 AT_ERROR( 110 "Dereferencing the DataLoader's past-the-end iterator is not allowed"); 111 } 112 113 /// Does double dispatch. 114 bool operator==(const IteratorImpl<Batch>& other) const override { 115 return other == *this; 116 } 117 118 /// Calls the comparison operator between `ValidIterator` and 119 /// `SentinelIterator`. 120 bool operator==(const ValidIterator<Batch>& other) const override { 121 return other == *this; 122 } 123 124 /// Sentinel iterators always compare equal. 125 bool operator==(const SentinelIterator<Batch>& other) const override { 126 return true; 127 } 128 }; 129 } // namespace detail 130 131 template <typename Batch> 132 class Iterator { 133 public: 134 // Type aliases to make the class recognized as a proper iterator. 135 using difference_type = std::ptrdiff_t; 136 using value_type = Batch; 137 using pointer = Batch*; 138 using reference = Batch&; 139 using iterator_category = std::input_iterator_tag; 140 Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)141 explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl) 142 : impl_(std::move(impl)) {} 143 144 /// Increments the iterator. 145 /// Only permitted for valid iterators (not past the end). 146 Iterator& operator++() { 147 impl_->next(); 148 return *this; 149 } 150 151 /// Returns the current batch. 152 /// Only permitted for valid iterators (not past the end). 153 Batch& operator*() { 154 return impl_->get(); 155 } 156 157 /// Returns a pointer to the current batch. 158 /// Only permitted for valid iterators (not past the end). 159 Batch* operator->() { 160 return &impl_->get(); 161 } 162 163 /// Compares two iterators for equality. 164 bool operator==(const Iterator& other) const { 165 return *impl_ == *other.impl_; 166 } 167 168 /// Compares two iterators for inequality. 169 bool operator!=(const Iterator& other) const { 170 return !(*this == other); 171 } 172 173 private: 174 /// Points either to a `ValidIterator` or to a `SentinelIterator`. 175 std::shared_ptr<detail::IteratorImpl<Batch>> impl_; 176 }; 177 } // namespace data 178 } // namespace torch 179