xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/iterator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/utils/variadic.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 
8 #include <functional>
9 #include <iterator>
10 #include <memory>
11 #include <type_traits>
12 #include <utility>
13 
14 namespace torch {
15 namespace data {
16 namespace detail {
17 // For increased safety and more separated logic, this implementation of
18 // `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A
19 // `ValidIterator` yields new batches until the `DataLoader` is exhausted. While
20 // the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are
21 // the same object. When the `ValidIterator` becomes exhausted, it compares
22 // equal to the `SentinelIterator`, but not before. Half the code here is to
23 // implement double dispatch for the comparison, because C++ lacks multimethods.
24 
// Forward declarations: `IteratorImpl` names both concrete implementations in
// its comparison overloads before either of them is defined.
template <typename Batch>
struct ValidIterator;

template <typename Batch>
struct SentinelIterator;

/// Abstract interface implemented by `ValidIterator` and `SentinelIterator`.
///
/// Equality is resolved with double dispatch: the overload taking the base
/// type is the entry point, and each implementation forwards it to
/// `other == *this`, so that the call lands on the overload matching the
/// concrete type of the original left-hand side.
template <typename Batch>
struct IteratorImpl {
  /// Virtual so that implementations can be destroyed through a base pointer.
  virtual ~IteratorImpl() = default;

  /// Advances to the next batch.
  virtual void next() = 0;

  /// Returns a reference to the current batch.
  virtual Batch& get() = 0;

  /// Compares against an implementation whose concrete type is unknown.
  virtual bool operator==(const IteratorImpl& other) const = 0;

  /// Compares against a `ValidIterator`.
  virtual bool operator==(const ValidIterator<Batch>& other) const = 0;

  /// Compares against a `SentinelIterator`.
  virtual bool operator==(const SentinelIterator<Batch>& other) const = 0;
};
41 
42 template <typename Batch>
43 struct ValidIterator : public IteratorImpl<Batch> {
44   using BatchProducer = std::function<std::optional<Batch>()>;
45 
ValidIteratorValidIterator46   explicit ValidIterator(BatchProducer next_batch)
47       : next_batch_(std::move(next_batch)) {}
48 
49   /// Fetches the next batch.
nextValidIterator50   void next() override {
51     // If we didn't get the very first batch yet, get it now.
52     lazy_initialize();
53     TORCH_CHECK(
54         batch_.has_value(), "Attempted to increment iterator past the end");
55     // Increment to the next batch.
56     batch_ = next_batch_();
57   }
58 
59   /// Returns the current batch. The precondition for this operation to not
60   /// throw an exception is that it has been compared to the `SentinelIterator`
61   /// and did not compare equal.
getValidIterator62   Batch& get() override {
63     // If we didn't get the very first batch yet, get it now.
64     lazy_initialize();
65     TORCH_CHECK(
66         batch_.has_value(),
67         "Attempted to dereference iterator that was past the end");
68     return batch_.value();
69   }
70 
71   /// Does double dispatch.
72   bool operator==(const IteratorImpl<Batch>& other) const override {
73     return other == *this;
74   }
75 
76   /// A `ValidIterator` is equal to the `SentinelIterator` iff. the
77   /// `ValidIterator` has reached the end of the dataloader.
78   bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
79     lazy_initialize();
80     return !batch_;
81   }
82 
83   /// Returns true if the memory address of `other` equals that of `this`.
84   bool operator==(const ValidIterator<Batch>& other) const override {
85     return &other == this;
86   }
87 
88   /// Gets the very first batch if it has not yet been fetched.
lazy_initializeValidIterator89   void lazy_initialize() const {
90     if (!initialized_) {
91       batch_ = next_batch_();
92       initialized_ = true;
93     }
94   }
95 
96   BatchProducer next_batch_;
97   mutable std::optional<Batch> batch_;
98   mutable bool initialized_ = false;
99 };
100 
101 template <typename Batch>
102 struct SentinelIterator : public IteratorImpl<Batch> {
nextSentinelIterator103   void next() override {
104     AT_ERROR(
105         "Incrementing the DataLoader's past-the-end iterator is not allowed");
106   }
107 
getSentinelIterator108   Batch& get() override {
109     AT_ERROR(
110         "Dereferencing the DataLoader's past-the-end iterator is not allowed");
111   }
112 
113   /// Does double dispatch.
114   bool operator==(const IteratorImpl<Batch>& other) const override {
115     return other == *this;
116   }
117 
118   /// Calls the comparison operator between `ValidIterator` and
119   /// `SentinelIterator`.
120   bool operator==(const ValidIterator<Batch>& other) const override {
121     return other == *this;
122   }
123 
124   /// Sentinel iterators always compare equal.
125   bool operator==(const SentinelIterator<Batch>& other) const override {
126     return true;
127   }
128 };
129 } // namespace detail
130 
131 template <typename Batch>
132 class Iterator {
133  public:
134   // Type aliases to make the class recognized as a proper iterator.
135   using difference_type = std::ptrdiff_t;
136   using value_type = Batch;
137   using pointer = Batch*;
138   using reference = Batch&;
139   using iterator_category = std::input_iterator_tag;
140 
Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)141   explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)
142       : impl_(std::move(impl)) {}
143 
144   /// Increments the iterator.
145   /// Only permitted for valid iterators (not past the end).
146   Iterator& operator++() {
147     impl_->next();
148     return *this;
149   }
150 
151   /// Returns the current batch.
152   /// Only permitted for valid iterators (not past the end).
153   Batch& operator*() {
154     return impl_->get();
155   }
156 
157   /// Returns a pointer to the current batch.
158   /// Only permitted for valid iterators (not past the end).
159   Batch* operator->() {
160     return &impl_->get();
161   }
162 
163   /// Compares two iterators for equality.
164   bool operator==(const Iterator& other) const {
165     return *impl_ == *other.impl_;
166   }
167 
168   /// Compares two iterators for inequality.
169   bool operator!=(const Iterator& other) const {
170     return !(*this == other);
171   }
172 
173  private:
174   /// Points either to a `ValidIterator` or to a `SentinelIterator`.
175   std::shared_ptr<detail::IteratorImpl<Batch>> impl_;
176 };
177 } // namespace data
178 } // namespace torch
179