xref: /aosp_15_r20/external/icing/icing/scoring/ranker.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/scoring/ranker.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <utility>
20 #include <vector>
21 
22 #include "icing/text_classifier/lib3/utils/base/statusor.h"
23 #include "icing/absl_ports/canonical_errors.h"
24 #include "icing/index/term-metadata.h"
25 #include "icing/scoring/scored-document-hit.h"
26 #include "icing/util/logging.h"
27 
28 namespace icing {
29 namespace lib {
30 
31 namespace {
32 // For all the heap manipulations in this file, we use a vector to represent the
33 // heap. The element at index 0 is the root node. For any node at index i, its
34 // left child node is at 2 * i + 1, its right child node is at 2 * i + 2.
35 
36 // Helper function to wrap the heapify algorithm, it heapifies the target
37 // subtree node in place.
38 // TODO(b/152934343) refactor the heapify function and making it into a class.
Heapify(std::vector<ScoredDocumentHit> * scored_document_hits,int target_subtree_root_index,const ScoredDocumentHitComparator & scored_document_hit_comparator)39 void Heapify(
40     std::vector<ScoredDocumentHit>* scored_document_hits,
41     int target_subtree_root_index,
42     const ScoredDocumentHitComparator& scored_document_hit_comparator) {
43   const int heap_size = scored_document_hits->size();
44   if (target_subtree_root_index >= heap_size) {
45     return;
46   }
47 
48   // Initializes subtree root as the current best node.
49   int best = target_subtree_root_index;
50   // If we represent a heap in an array/vector, indices of left and right
51   // children can be calculated.
52   const int left = target_subtree_root_index * 2 + 1;
53   const int right = target_subtree_root_index * 2 + 2;
54 
55   // If left child is better than current best
56   if (left < heap_size &&
57       scored_document_hit_comparator(scored_document_hits->at(left),
58                                      scored_document_hits->at(best))) {
59     best = left;
60   }
61 
62   // If right child is better than current best
63   if (right < heap_size &&
64       scored_document_hit_comparator(scored_document_hits->at(right),
65                                      scored_document_hits->at(best))) {
66     best = right;
67   }
68 
69   // If the best is not the subtree root, swap and continue heapifying the lower
70   // level subtree
71   if (best != target_subtree_root_index) {
72     std::swap(scored_document_hits->at(best),
73               scored_document_hits->at(target_subtree_root_index));
74     Heapify(scored_document_hits, best, scored_document_hit_comparator);
75   }
76 }
77 
78 // Heapify the given term vector from top to bottom. Call it after add or
79 // replace an element at the front of the vector.
HeapifyTermDown(std::vector<TermMetadata> & scored_terms,int target_subtree_root_index)80 void HeapifyTermDown(std::vector<TermMetadata>& scored_terms,
81                      int target_subtree_root_index) {
82   int heap_size = scored_terms.size();
83   if (target_subtree_root_index >= heap_size) {
84     return;
85   }
86 
87   // Initializes subtree root as the current minimum node.
88   int min = target_subtree_root_index;
89   // If we represent a heap in an array/vector, indices of left and right
90   // children can be calculated as such.
91   const int left = target_subtree_root_index * 2 + 1;
92   const int right = target_subtree_root_index * 2 + 2;
93 
94   // If left child is smaller than current minimum.
95   if (left < heap_size &&
96       scored_terms.at(left).score < scored_terms.at(min).score) {
97     min = left;
98   }
99 
100   // If right child is smaller than current minimum.
101   if (right < heap_size &&
102       scored_terms.at(right).score < scored_terms.at(min).score) {
103     min = right;
104   }
105 
106   // If the minimum is not the subtree root, swap and continue heapifying the
107   // lower level subtree.
108   if (min != target_subtree_root_index) {
109     std::swap(scored_terms.at(min), scored_terms.at(target_subtree_root_index));
110     HeapifyTermDown(scored_terms, min);
111   }
112 }
113 
114 // Heapify the given term vector from bottom to top. Call it after add an
115 // element at the end of the vector.
HeapifyTermUp(std::vector<TermMetadata> & scored_terms,int target_subtree_child_index)116 void HeapifyTermUp(std::vector<TermMetadata>& scored_terms,
117                    int target_subtree_child_index) {
118   // If we represent a heap in an array/vector, indices of root can be
119   // calculated as such.
120   const int root = (target_subtree_child_index + 1) / 2 - 1;
121 
122   // If the current child is smaller than the root, swap and continue heapifying
123   // the upper level subtree
124   if (root >= 0 && scored_terms.at(target_subtree_child_index).score <
125                        scored_terms.at(root).score) {
126     std::swap(scored_terms.at(root),
127               scored_terms.at(target_subtree_child_index));
128     HeapifyTermUp(scored_terms, root);
129   }
130 }
131 
PopRootTerm(std::vector<TermMetadata> & scored_terms)132 TermMetadata PopRootTerm(std::vector<TermMetadata>& scored_terms) {
133   if (scored_terms.empty()) {
134     // Return an invalid TermMetadata as a sentinel value.
135     return TermMetadata(/*content_in=*/"", /*hit_count_in=*/-1);
136   }
137 
138   // Steps to extract root from heap:
139   // 1. copy out root
140   TermMetadata root = scored_terms.at(0);
141   const size_t last_node_index = scored_terms.size() - 1;
142   // 2. swap root and the last node
143   std::swap(scored_terms.at(0), scored_terms.at(last_node_index));
144   // 3. remove last node
145   scored_terms.pop_back();
146   // 4. heapify root
147   HeapifyTermDown(scored_terms, /*target_subtree_root_index=*/0);
148   return root;
149 }
150 
151 }  // namespace
152 
BuildHeapInPlace(std::vector<ScoredDocumentHit> * scored_document_hits,const ScoredDocumentHitComparator & scored_document_hit_comparator)153 void BuildHeapInPlace(
154     std::vector<ScoredDocumentHit>* scored_document_hits,
155     const ScoredDocumentHitComparator& scored_document_hit_comparator) {
156   const int heap_size = scored_document_hits->size();
157   // Since we use a vector to represent the heap, [size / 2 - 1] is the index
158   // of the parent node of the last node.
159   for (int subtree_root_index = heap_size / 2 - 1; subtree_root_index >= 0;
160        subtree_root_index--) {
161     Heapify(scored_document_hits, subtree_root_index,
162             scored_document_hit_comparator);
163   }
164 }
165 
PushToTermHeap(TermMetadata term,int number_to_return,std::vector<TermMetadata> & scored_terms_heap)166 void PushToTermHeap(TermMetadata term, int number_to_return,
167                     std::vector<TermMetadata>& scored_terms_heap) {
168   if (scored_terms_heap.size() < number_to_return) {
169     scored_terms_heap.push_back(std::move(term));
170     // We insert at end, so we should heapify bottom up.
171     HeapifyTermUp(scored_terms_heap, scored_terms_heap.size() - 1);
172   } else if (scored_terms_heap.at(0).score < term.score) {
173     scored_terms_heap.at(0) = std::move(term);
174     // We insert at root, so we should heapify top down.
175     HeapifyTermDown(scored_terms_heap, /*target_subtree_root_index=*/0);
176   }
177 }
178 
PopNextTopResultFromHeap(std::vector<ScoredDocumentHit> * scored_document_hits_heap,const ScoredDocumentHitComparator & scored_document_hit_comparator)179 libtextclassifier3::StatusOr<ScoredDocumentHit> PopNextTopResultFromHeap(
180     std::vector<ScoredDocumentHit>* scored_document_hits_heap,
181     const ScoredDocumentHitComparator& scored_document_hit_comparator) {
182   if (scored_document_hits_heap->empty()) {
183     // An invalid ScoredDocumentHit
184     return absl_ports::ResourceExhaustedError("Heap is empty");
185   }
186 
187   // Steps to extract root from heap:
188   // 1. copy out root
189   ScoredDocumentHit root = std::move(scored_document_hits_heap->at(0));
190   const size_t last_node_index = scored_document_hits_heap->size() - 1;
191   // 2. swap root and the last node
192   std::swap(scored_document_hits_heap->at(0),
193             scored_document_hits_heap->at(last_node_index));
194   // 3. remove last node
195   scored_document_hits_heap->pop_back();
196   // 4. heapify root
197   Heapify(scored_document_hits_heap, /*target_subtree_root_index=*/0,
198           scored_document_hit_comparator);
199   return root;
200 }
201 
PopTopResultsFromHeap(std::vector<ScoredDocumentHit> * scored_document_hits_heap,int num_results,const ScoredDocumentHitComparator & scored_document_hit_comparator)202 std::vector<ScoredDocumentHit> PopTopResultsFromHeap(
203     std::vector<ScoredDocumentHit>* scored_document_hits_heap, int num_results,
204     const ScoredDocumentHitComparator& scored_document_hit_comparator) {
205   std::vector<ScoredDocumentHit> scored_document_hit_result;
206   int result_size = std::min(
207       num_results, static_cast<int>(scored_document_hits_heap->size()));
208   while (result_size-- > 0) {
209     libtextclassifier3::StatusOr<ScoredDocumentHit> next_best_document_hit_or =
210         PopNextTopResultFromHeap(scored_document_hits_heap,
211                                  scored_document_hit_comparator);
212     if (next_best_document_hit_or.ok()) {
213       scored_document_hit_result.push_back(
214           std::move(next_best_document_hit_or).ValueOrDie());
215     } else {
216       ICING_VLOG(1) << next_best_document_hit_or.status().error_message();
217     }
218   }
219   return scored_document_hit_result;
220 }
221 
PopAllTermsFromHeap(std::vector<TermMetadata> & scored_terms_heap)222 std::vector<TermMetadata> PopAllTermsFromHeap(
223     std::vector<TermMetadata>& scored_terms_heap) {
224   std::vector<TermMetadata> top_term_result;
225   top_term_result.reserve(scored_terms_heap.size());
226   while (!scored_terms_heap.empty()) {
227     top_term_result.push_back(PopRootTerm(scored_terms_heap));
228   }
229   return top_term_result;
230 }
231 
232 }  // namespace lib
233 }  // namespace icing
234