1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
18
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>

#include "tensorflow/core/lib/gtl/iterator_range.h"
24
25 namespace xla {
26
// UnwrappingIterator is a transforming iterator that calls get() on the
// elements it returns.
//
// Together with tensorflow::gtl::iterator_range, this lets classes which
// contain a collection of smart pointers expose a view of raw pointers to
// consumers. For example:
//
//   class MyContainer {
//    public:
//     tensorflow::gtl::iterator_range<
//         UnwrappingIterator<std::vector<std::unique_ptr<Thing>>::iterator>>
//     things() {
//       return {MakeUnwrappingIterator(things_.begin()),
//               MakeUnwrappingIterator(things_.end())};
//     }
//
//     tensorflow::gtl::iterator_range<UnwrappingIterator<
//         std::vector<std::unique_ptr<Thing>>::const_iterator>>
//     things() const {
//       return {MakeUnwrappingIterator(things_.begin()),
//               MakeUnwrappingIterator(things_.end())};
//     }
//
//    private:
//     std::vector<std::unique_ptr<Thing>> things_;
//   };
//
//   MyContainer container = ...;
//   for (Thing* t : container.things()) {
//     ...
//   }
//
// For simplicity, UnwrappingIterator is currently unconditionally an
// input_iterator -- it doesn't inherit any superpowers NestedIter may have.
template <typename NestedIter>
class UnwrappingIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  // The type returned by get() on the nested element (e.g. Thing* for a
  // std::unique_ptr<Thing>).
  using value_type = decltype(std::declval<NestedIter>()->get());
  using difference_type = std::ptrdiff_t;
  using pointer = value_type*;
  // operator* returns the unwrapped value by value (there is no stored object
  // to refer to), so `reference` must be that same type. Declaring it as
  // value_type& would advertise an lvalue that does not exist and makes
  // adaptors such as std::move_iterator bind references to temporaries.
  using reference = value_type;

  explicit UnwrappingIterator(NestedIter iter) : iter_(std::move(iter)) {}

  // Returns get() of the current nested element. Does not modify the
  // iterator, so it is also usable on const UnwrappingIterators.
  reference operator*() const { return iter_->get(); }

  // Pre-increment: advances the nested iterator by one element.
  UnwrappingIterator& operator++() {
    ++iter_;
    return *this;
  }

  // Post-increment: returns a copy of the iterator before advancing.
  UnwrappingIterator operator++(int) {
    UnwrappingIterator temp(*this);
    ++*this;
    return temp;
  }

  // Two UnwrappingIterators are equal iff their nested iterators are equal.
  friend bool operator==(const UnwrappingIterator& a,
                         const UnwrappingIterator& b) {
    return a.iter_ == b.iter_;
  }

  friend bool operator!=(const UnwrappingIterator& a,
                         const UnwrappingIterator& b) {
    return !(a == b);
  }

 private:
  NestedIter iter_;
};
96
97 template <typename NestedIter>
MakeUnwrappingIterator(NestedIter iter)98 UnwrappingIterator<NestedIter> MakeUnwrappingIterator(NestedIter iter) {
99 return UnwrappingIterator<NestedIter>(std::move(iter));
100 }
101
// An iterator that filters out values where the predicate(value) evaluates to
// false: starting from `iter`, it only ever stops on elements of
// [iter, end_iter) for which pred(*element) is true, or on end_iter. An
// unwrapping iterator can be nested inside a filtering iterator to also unwrap
// smart pointers.
//
// The nested iterator may yield either an lvalue reference (e.g. an iterator
// into a std::vector<int>) or a prvalue (e.g. an UnwrappingIterator);
// operator* returns exactly what the nested iterator's operator* returns.
//
// NOTE(review): if UnaryPredicate is not copy-assignable (e.g. a lambda before
// C++20), FilteringIterator is not copy-assignable either.
template <typename NestedIter, typename UnaryPredicate>
class FilteringIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  // The exact type produced by dereferencing the nested iterator. It may be a
  // reference, so the remaining traits are derived from it by stripping the
  // reference rather than by appending '*' to it (a pointer to a reference is
  // ill-formed and would make the whole class fail to instantiate).
  using reference = decltype(*std::declval<NestedIter&>());
  using value_type = typename std::remove_cv<
      typename std::remove_reference<reference>::type>::type;
  using difference_type = std::ptrdiff_t;
  using pointer = typename std::remove_reference<reference>::type*;

  // Positions the iterator on the first element of [iter, end_iter) that
  // satisfies pred, or on end_iter if no element does.
  FilteringIterator(NestedIter iter, NestedIter end_iter, UnaryPredicate pred)
      : iter_(std::move(iter)),
        end_iter_(std::move(end_iter)),
        pred_(std::move(pred)) {
    if (iter_ != end_iter_ && !pred_(**this)) {
      ++*this;
    }
  }

  // Forwards the nested iterator's dereference result unchanged.
  reference operator*() { return *iter_; }

  // Advances to the next element that satisfies the predicate, or to the end
  // iterator. Must not be called on an iterator that is already at the end.
  FilteringIterator& operator++() {
    do {
      ++iter_;
    } while (iter_ != end_iter_ && !pred_(**this));
    return *this;
  }

  // Copies *this instead of re-running the constructor, so the predicate is
  // not evaluated a second time on the current element.
  FilteringIterator operator++(int) {
    FilteringIterator temp(*this);
    operator++();
    return temp;
  }

  // Equality compares only the current nested position; end_iter_ and pred_
  // are not compared.
  friend bool operator==(const FilteringIterator& a,
                         const FilteringIterator& b) {
    return a.iter_ == b.iter_;
  }

  friend bool operator!=(const FilteringIterator& a,
                         const FilteringIterator& b) {
    return !(a == b);
  }

 private:
  NestedIter iter_;
  NestedIter end_iter_;
  UnaryPredicate pred_;
};
151
152 template <typename NestedIter, typename UnaryPredicate>
153 using FilteringUnwrappingIterator =
154 FilteringIterator<UnwrappingIterator<NestedIter>, UnaryPredicate>;
155
156 // Create and return a filtering unwrapping iterator.
157 template <typename NestedIter, typename UnaryPredicate>
158 FilteringUnwrappingIterator<NestedIter, UnaryPredicate>
MakeFilteringUnwrappingIterator(NestedIter iter,NestedIter end_iter,UnaryPredicate pred)159 MakeFilteringUnwrappingIterator(NestedIter iter, NestedIter end_iter,
160 UnaryPredicate pred) {
161 return FilteringUnwrappingIterator<NestedIter, UnaryPredicate>(
162 MakeUnwrappingIterator(iter), MakeUnwrappingIterator(end_iter),
163 std::move(pred));
164 }
165
166 // Create and return a filtering unwrapping iterator range.
167 template <typename NestedIter, typename UnaryPredicate>
168 tensorflow::gtl::iterator_range<
169 FilteringUnwrappingIterator<NestedIter, UnaryPredicate>>
MakeFilteringUnwrappingIteratorRange(NestedIter begin_iter,NestedIter end_iter,UnaryPredicate pred)170 MakeFilteringUnwrappingIteratorRange(NestedIter begin_iter, NestedIter end_iter,
171 UnaryPredicate pred) {
172 return {MakeFilteringUnwrappingIterator(begin_iter, end_iter, pred),
173 MakeFilteringUnwrappingIterator(end_iter, end_iter, pred)};
174 }
175
176 } // namespace xla
177
178 #endif // TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
179