xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/iterator_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
18 
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>
22 
23 #include "tensorflow/core/lib/gtl/iterator_range.h"
24 
25 namespace xla {
26 
// UnwrappingIterator is a transforming iterator that calls get() on the
// elements it returns.
//
// Together with tensorflow::gtl::iterator_range, this lets classes which
// contain a collection of smart pointers expose a view of raw pointers to
// consumers.  For example:
//
//  class MyContainer {
//   public:
//    tensorflow::gtl::iterator_range<
//        UnwrappingIterator<std::vector<std::unique_ptr<Thing>>::iterator>>
//    things() {
//      return {MakeUnwrappingIterator(things_.begin()),
//              MakeUnwrappingIterator(things_.end())};
//    }
//
//    tensorflow::gtl::iterator_range<UnwrappingIterator<
//        std::vector<std::unique_ptr<Thing>>::const_iterator>>
//    things() const {
//      return {MakeUnwrappingIterator(things_.begin()),
//              MakeUnwrappingIterator(things_.end())};
//    }
//
//   private:
//    std::vector<std::unique_ptr<Thing>> things_;
//  };
//
//  MyContainer container = ...;
//  for (Thing* t : container.things()) {
//    ...
//  }
//
// For simplicity, UnwrappingIterator is currently unconditionally an
// input_iterator -- it doesn't inherit any superpowers NestedIter may have.
//
// Dereferencing yields the result of get() *by value* (a prvalue), so the
// `reference` typedef is value_type itself rather than value_type&; this keeps
// std::iterator_traits consistent with what operator* actually returns.
template <typename NestedIter>
class UnwrappingIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  // Whatever the element's get() returns, e.g. T* for std::unique_ptr<T>.
  using value_type = decltype(std::declval<NestedIter>()->get());
  using difference_type = std::ptrdiff_t;
  using pointer = value_type*;
  using reference = value_type;

  explicit UnwrappingIterator(NestedIter iter) : iter_(std::move(iter)) {}

  // const so that a const iterator can be dereferenced, as input iterators
  // require; it only reads through iter_ and never modifies it.
  reference operator*() const { return iter_->get(); }

  UnwrappingIterator& operator++() {
    ++iter_;
    return *this;
  }

  UnwrappingIterator operator++(int) {
    UnwrappingIterator previous = *this;
    ++*this;
    return previous;
  }

  // Iterators compare equal exactly when their nested iterators do.
  friend bool operator==(const UnwrappingIterator& a,
                         const UnwrappingIterator& b) {
    return a.iter_ == b.iter_;
  }

  friend bool operator!=(const UnwrappingIterator& a,
                         const UnwrappingIterator& b) {
    return !(a == b);
  }

 private:
  NestedIter iter_;
};
96 
97 template <typename NestedIter>
MakeUnwrappingIterator(NestedIter iter)98 UnwrappingIterator<NestedIter> MakeUnwrappingIterator(NestedIter iter) {
99   return UnwrappingIterator<NestedIter>(std::move(iter));
100 }
101 
// An iterator that filters out values where the predicate(value) evaluates to
// false. An unwrapping iterator can be nested inside a filtering iterator to
// also unwrap smart pointers.
//
// On construction the iterator is advanced to the first element of
// [iter, end_iter) accepted by the predicate (or to end_iter if there is
// none), and each increment moves to the next accepted element. Like
// UnwrappingIterator, it is unconditionally an input iterator.
//
// NestedIter's operator* may return either a value (as UnwrappingIterator
// does) or a reference (as std::vector<T>::iterator does). `reference` mirrors
// exactly what the nested operator* returns; value_type and pointer are
// derived from it so that no pointer-to-reference type is ever formed.
template <typename NestedIter, typename UnaryPredicate>
class FilteringIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  using reference = decltype(*std::declval<NestedIter&>());
  using value_type = std::remove_cv_t<std::remove_reference_t<reference>>;
  using difference_type = std::ptrdiff_t;
  using pointer = std::add_pointer_t<std::remove_reference_t<reference>>;

  FilteringIterator(NestedIter iter, NestedIter end_iter, UnaryPredicate pred)
      : iter_(std::move(iter)),
        end_iter_(std::move(end_iter)),
        pred_(std::move(pred)) {
    // Skip any leading elements that the predicate rejects.
    if (iter_ != end_iter_ && !pred_(**this)) {
      ++*this;
    }
  }

  // A non-const overload is kept because NestedIter's operator* is not
  // necessarily callable on a const object. The const overload's body is only
  // instantiated if it is actually used.
  reference operator*() { return *iter_; }
  reference operator*() const { return *iter_; }

  FilteringIterator& operator++() {
    // Advance at least once, then keep going until the predicate accepts the
    // current element or the end is reached.
    do {
      ++iter_;
    } while (iter_ != end_iter_ && !pred_(**this));
    return *this;
  }

  FilteringIterator operator++(int) {
    // Copy the current state instead of re-running the constructor, which
    // would call the predicate again on an element already known to pass.
    FilteringIterator previous = *this;
    ++*this;
    return previous;
  }

  // Only the current nested position is compared; end_iter_ and pred_ are
  // not part of equality.
  friend bool operator==(const FilteringIterator& a,
                         const FilteringIterator& b) {
    return a.iter_ == b.iter_;
  }

  friend bool operator!=(const FilteringIterator& a,
                         const FilteringIterator& b) {
    return !(a == b);
  }

 private:
  NestedIter iter_;
  NestedIter end_iter_;
  UnaryPredicate pred_;
};
151 
152 template <typename NestedIter, typename UnaryPredicate>
153 using FilteringUnwrappingIterator =
154     FilteringIterator<UnwrappingIterator<NestedIter>, UnaryPredicate>;
155 
156 // Create and return a filtering unwrapping iterator.
157 template <typename NestedIter, typename UnaryPredicate>
158 FilteringUnwrappingIterator<NestedIter, UnaryPredicate>
MakeFilteringUnwrappingIterator(NestedIter iter,NestedIter end_iter,UnaryPredicate pred)159 MakeFilteringUnwrappingIterator(NestedIter iter, NestedIter end_iter,
160                                 UnaryPredicate pred) {
161   return FilteringUnwrappingIterator<NestedIter, UnaryPredicate>(
162       MakeUnwrappingIterator(iter), MakeUnwrappingIterator(end_iter),
163       std::move(pred));
164 }
165 
166 // Create and return a filtering unwrapping iterator range.
167 template <typename NestedIter, typename UnaryPredicate>
168 tensorflow::gtl::iterator_range<
169     FilteringUnwrappingIterator<NestedIter, UnaryPredicate>>
MakeFilteringUnwrappingIteratorRange(NestedIter begin_iter,NestedIter end_iter,UnaryPredicate pred)170 MakeFilteringUnwrappingIteratorRange(NestedIter begin_iter, NestedIter end_iter,
171                                      UnaryPredicate pred) {
172   return {MakeFilteringUnwrappingIterator(begin_iter, end_iter, pred),
173           MakeFilteringUnwrappingIterator(end_iter, end_iter, pred)};
174 }
175 
176 }  // namespace xla
177 
178 #endif  // TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
179