xref: /aosp_15_r20/external/tensorflow/tensorflow/core/lib/gtl/top_n.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This simple class finds the top n elements of an incrementally provided set
// of elements which you push one at a time.  If the number of elements exceeds
// n, the lowest elements are incrementally dropped.  At the end you get
// a vector of the top elements sorted in descending order (through Extract() or
// ExtractNondestructive()), or a vector of the top elements but not sorted
// (through ExtractUnsorted() or ExtractUnsortedNondestructive()).
//
// The value n is specified in the constructor.  If there are p elements pushed
// altogether:
//   The total storage requirement is O(min(n, p)) elements.
//   The running time is O(p * log(min(n, p))) comparisons.
// If n is a constant, the total storage required is a constant and the running
// time is linear in p.
//
// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p)
// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements,
// discarding the lowest n elements whenever the buffer is full using a
// linear-time median (selection) algorithm. This may have better performance
// when the input sequence is partially sorted.
//
// NOTE(zhifengc): This class should be redesigned to avoid reallocating a
// vector for each Extract.
38 
39 #ifndef TENSORFLOW_LIB_GTL_TOP_N_H_
40 #define TENSORFLOW_LIB_GTL_TOP_N_H_
41 
42 #include <stddef.h>
43 #include <algorithm>
44 #include <functional>
45 #include <string>
46 #include <vector>
47 
48 #include "tensorflow/core/platform/logging.h"
49 
50 namespace tensorflow {
51 namespace gtl {
52 
// Cmp is an stl binary predicate.  Note that Cmp is the "greater" predicate,
// not the more commonly used "less" predicate.
//
// If you use a "less" predicate here, the TopN will pick out the bottom N
// elements out of the ones passed to it, and it will return them sorted in
// ascending order.
//
// TopN is rule-of-zero copyable and movable if its members are.
template <class T, class Cmp = std::greater<T> >
class TopN {
 public:
  // The TopN is in one of the three states:
  //
  //  o UNORDERED: this is the state an instance is originally in,
  //    where the elements are completely orderless.
  //
  //  o BOTTOM_KNOWN: in this state, we keep the invariant that there
  //    is at least one element in it, and the lowest element is at
  //    position 0. The elements in other positions remain
  //    unsorted. This state is reached if the state was originally
  //    UNORDERED and a peek_bottom() function call is invoked.
  //
  //  o HEAP_SORTED: in this state, the array is kept as a heap and
  //    there are exactly limit_ elements in the array. This
  //    state is reached when at least (limit_+1) elements are
  //    pushed in.
  //
  //  The state transition graph is as follows:
  //
  //             peek_bottom()                (limit_+1) elements pushed
  //  UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED
  //      |                                                           ^
  //      |                (limit_+1) elements pushed                 |
  //      +-----------------------------------------------------------+
  //
  //  Reset() returns the object to UNORDERED from any state.

  enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED };
  using UnsortedIterator = typename std::vector<T>::const_iterator;

  // 'limit' is the maximum number of top results to return.
  explicit TopN(size_t limit) : TopN(limit, Cmp()) {}
  TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {}

  size_t limit() const { return limit_; }

  // Number of elements currently held by this TopN object.  This
  // will be no greater than 'limit' passed to the constructor.
  size_t size() const { return elements_.size(); }

  bool empty() const { return size() == 0; }

  // If you know how many elements you will push at the time you create the
  // TopN object, you can call reserve to preallocate the memory that TopN
  // will need to process all 'n' pushes.  Calling this method is optional.
  void reserve(size_t n) {
    // We may need limit_+1 slots for the case where we transition from an
    // unsorted set of limit_ elements to a heap, so the amount reserved is
    // min(n, limit_ + 1).  It is written as a comparison so that limit_ + 1
    // is only evaluated when n > limit_ (hence limit_ < SIZE_MAX); computing
    // limit_ + 1 unconditionally wraps to 0 for limit_ == SIZE_MAX and would
    // turn this call into a no-op.
    elements_.reserve(n > limit_ ? limit_ + 1 : n);
  }

  // Push 'v'.  If the maximum number of elements was exceeded, drop the
  // lowest element and return it in 'dropped' (if given). If the maximum is not
  // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or
  // nullptr, in which case it is not filled in.
  // Requires: T is CopyAssignable, Swappable
  void push(const T &v) { push(v, nullptr); }
  void push(const T &v, T *dropped) { PushInternal(v, dropped); }

  // Move overloads of push.
  // Requires: T is MoveAssignable, Swappable
  void push(T &&v) {  // NOLINT(build/c++11)
    push(std::move(v), nullptr);
  }
  void push(T &&v, T *dropped) {  // NOLINT(build/c++11)
    PushInternal(std::move(v), dropped);
  }

  // Peeks the bottom result without calling Extract().  Requires that the
  // TopN is non-empty.
  const T &peek_bottom();

  // Extract the elements as a vector sorted in descending order.  The caller
  // assumes ownership of the vector and must delete it when done.  This is a
  // destructive operation: afterwards the TopN holds no elements.  Call
  // Reset() before reusing the object.
  std::vector<T> *Extract();

  // Similar to Extract(), but makes no guarantees the elements are in sorted
  // order.  As with Extract(), the caller assumes ownership of the vector and
  // must delete it when done.  This is a destructive operation: afterwards the
  // TopN holds no elements.  Call Reset() before reusing the object.
  std::vector<T> *ExtractUnsorted();

  // A non-destructive version of Extract(). Copy the elements in a new vector
  // sorted in descending order and return it.  The caller assumes ownership of
  // the new vector and must delete it when done.  After calling
  // ExtractNondestructive(), the caller can continue to push() new elements.
  std::vector<T> *ExtractNondestructive() const;

  // A non-destructive version of Extract(). Copy the elements to a given
  // vector sorted in descending order. After calling
  // ExtractNondestructive(), the caller can continue to push() new elements.
  // Note:
  //  1. The given argument must be non-null (it is CHECKed).
  //  2. Any data contained in the vector prior to the call will be deleted
  //     from it. After the call the vector will contain only the elements
  //     from the data structure.
  void ExtractNondestructive(std::vector<T> *output) const;

  // A non-destructive version of ExtractUnsorted(). Copy the elements in a new
  // vector and return it, with no guarantees the elements are in sorted order.
  // The caller assumes ownership of the new vector and must delete it when
  // done.  After calling ExtractUnsortedNondestructive(), the caller can
  // continue to push() new elements.
  std::vector<T> *ExtractUnsortedNondestructive() const;

  // A non-destructive version of ExtractUnsorted(). Copy the elements into
  // a given vector, with no guarantees the elements are in sorted order.
  // After calling ExtractUnsortedNondestructive(), the caller can continue
  // to push() new elements.
  // Note:
  //  1. The given argument must be non-null (it is CHECKed).
  //  2. Any data contained in the vector prior to the call will be deleted
  //     from it. After the call the vector will contain only the elements
  //     from the data structure.
  void ExtractUnsortedNondestructive(std::vector<T> *output) const;

  // Return an iterator to the beginning (end) of the container,
  // with no guarantees about the order of iteration. These iterators are
  // invalidated by mutation of the data structure.
  UnsortedIterator unsorted_begin() const { return elements_.begin(); }
  UnsortedIterator unsorted_end() const { return elements_.end(); }

  // Accessor for comparator template argument.
  Cmp *comparator() { return &cmp_; }

  // This removes all elements.  If Extract() or ExtractUnsorted() have been
  // called, this will put it back in an empty but usable state.
  void Reset();

 private:
  template <typename U>
  void PushInternal(U &&v, T *dropped);  // NOLINT(build/c++11)

  // Contents of elements_ by state:
  //   state_ != HEAP_SORTED: elements_.size() <= limit_ and elements_ is an
  //      unsorted vector of the elements pushed so far (with the lowest one
  //      at position 0 when state_ == BOTTOM_KNOWN).
  //   state_ == HEAP_SORTED: elements_.size() == limit_ and elements_ is an
  //      stl heap ordered by cmp_, so elements_.front() is the lowest kept
  //      element.
  // Extract() and ExtractUnsorted() leave elements_ empty.
  std::vector<T> elements_;
  size_t limit_;  // Maximum number of elements to find
  Cmp cmp_;       // Greater-than comparison function
  State state_ = UNORDERED;
};
205 
// ----------------------------------------------------------------------
// Implementations of non-inline functions
208 
209 template <class T, class Cmp>
210 template <typename U>
PushInternal(U && v,T * dropped)211 void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) {  // NOLINT(build/c++11)
212   if (limit_ == 0) {
213     if (dropped) *dropped = std::forward<U>(v);  // NOLINT(build/c++11)
214     return;
215   }
216   if (state_ != HEAP_SORTED) {
217     // We may temporarily extend one beyond limit_ elements here.  This is
218     // necessary for finding and removing the smallest element.
219     elements_.push_back(std::forward<U>(v));  // NOLINT(build/c++11)
220     if (elements_.size() == limit_ + 1) {
221       // Transition from unsorted vector to a heap.
222       std::make_heap(elements_.begin(), elements_.end(), cmp_);
223       std::pop_heap(elements_.begin(), elements_.end(), cmp_);
224       if (dropped) *dropped = std::move(elements_.back());
225       elements_.pop_back();  // Restore to size limit_.
226       state_ = HEAP_SORTED;
227     } else if (state_ == UNORDERED ||
228                cmp_(elements_.back(), elements_.front())) {
229       // Easy case: we just push the new element back
230     } else {
231       // To maintain the BOTTOM_KNOWN state, we need to make sure that
232       // the element at position 0 is always the smallest. So we put
233       // the new element at position 0 and push the original bottom
234       // element in the back.
235       // Warning: this code is subtle.
236       using std::swap;
237       swap(elements_.front(), elements_.back());
238     }
239 
240   } else {
241     // Only insert the new element if it is greater than the least element.
242     if (cmp_(v, elements_.front())) {
243       // Remove the top (smallest) element of the min heap, then push the new
244       // value in.
245       std::pop_heap(elements_.begin(), elements_.end(), cmp_);
246       if (dropped) *dropped = std::move(elements_.back());
247       elements_.back() = std::forward<U>(v);
248       std::push_heap(elements_.begin(), elements_.end(), cmp_);
249     } else {
250       if (dropped) *dropped = std::forward<U>(v);  // NOLINT(build/c++11)
251     }
252   }
253 }
254 
255 template <class T, class Cmp>
peek_bottom()256 const T &TopN<T, Cmp>::peek_bottom() {
257   CHECK(!empty());
258   if (state_ == UNORDERED) {
259     // We need to do a linear scan to find out the bottom element
260     int min_candidate = 0;
261     for (size_t i = 1; i < elements_.size(); ++i) {
262       if (cmp_(elements_[min_candidate], elements_[i])) {
263         min_candidate = i;
264       }
265     }
266     // By swapping the element at position 0 and the minimal
267     // element, we transition to the BOTTOM_KNOWN state
268     if (min_candidate != 0) {
269       using std::swap;
270       swap(elements_[0], elements_[min_candidate]);
271     }
272     state_ = BOTTOM_KNOWN;
273   }
274   return elements_.front();
275 }
276 
277 template <class T, class Cmp>
Extract()278 std::vector<T> *TopN<T, Cmp>::Extract() {
279   auto out = new std::vector<T>;
280   out->swap(elements_);
281   if (state_ != HEAP_SORTED) {
282     std::sort(out->begin(), out->end(), cmp_);
283   } else {
284     std::sort_heap(out->begin(), out->end(), cmp_);
285   }
286   return out;
287 }
288 
289 template <class T, class Cmp>
ExtractUnsorted()290 std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() {
291   auto out = new std::vector<T>;
292   out->swap(elements_);
293   return out;
294 }
295 
296 template <class T, class Cmp>
ExtractNondestructive()297 std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const {
298   auto out = new std::vector<T>;
299   ExtractNondestructive(out);
300   return out;
301 }
302 
303 template <class T, class Cmp>
ExtractNondestructive(std::vector<T> * output)304 void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const {
305   CHECK(output);
306   *output = elements_;
307   if (state_ != HEAP_SORTED) {
308     std::sort(output->begin(), output->end(), cmp_);
309   } else {
310     std::sort_heap(output->begin(), output->end(), cmp_);
311   }
312 }
313 
314 template <class T, class Cmp>
ExtractUnsortedNondestructive()315 std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const {
316   auto elements = new std::vector<T>;
317   ExtractUnsortedNondestructive(elements);
318   return elements;
319 }
320 
321 template <class T, class Cmp>
ExtractUnsortedNondestructive(std::vector<T> * output)322 void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const {
323   CHECK(output);
324   *output = elements_;
325 }
326 
327 template <class T, class Cmp>
Reset()328 void TopN<T, Cmp>::Reset() {
329   elements_.clear();
330   state_ = UNORDERED;
331 }
332 
333 }  // namespace gtl
334 }  // namespace tensorflow
335 
336 #endif  // TENSORFLOW_LIB_GTL_TOP_N_H_
337