xref: /aosp_15_r20/external/icing/icing/index/embed/embedding-query-results.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
16 #define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
17 
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "icing/proto/search.pb.h"
22 #include "icing/store/document-id.h"
23 
24 namespace icing {
25 namespace lib {
26 
27 // A class to store results generated from embedding queries.
28 struct EmbeddingQueryResults {
29   // Maps from DocumentId to the list of matched embedding scores for that
30   // document, which will be used in the advanced scoring language to
31   // determine the results for the "this.matchedSemanticScores(...)" function.
32   using EmbeddingQueryScoreMap =
33       std::unordered_map<DocumentId, std::vector<double>>;
34 
35   // Maps from (query_vector_index, metric_type) to EmbeddingQueryScoreMap.
36   std::unordered_map<
37       int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code,
38                               EmbeddingQueryScoreMap>>
39       result_scores;
40 
41   // Get the score map for the given query_vector_index and metric_type. Returns
42   // nullptr if (query_vector_index, metric_type) does not exist in the
43   // result_scores map.
GetScoreMapEmbeddingQueryResults44   const EmbeddingQueryScoreMap* GetScoreMap(
45       int query_vector_index,
46       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) const {
47     // Check if a mapping exists for the query_vector_index
48     auto outer_it = result_scores.find(query_vector_index);
49     if (outer_it == result_scores.end()) {
50       return nullptr;
51     }
52     // Check if a mapping exists for the metric_type
53     auto inner_it = outer_it->second.find(metric_type);
54     if (inner_it == outer_it->second.end()) {
55       return nullptr;
56     }
57     return &inner_it->second;
58   }
59 
60   // Returns the matched scores for the given query_vector_index, metric_type,
61   // and doc_id. Returns nullptr if (query_vector_index, metric_type, doc_id)
62   // does not exist in the result_scores map.
GetMatchedScoresForDocumentEmbeddingQueryResults63   const std::vector<double>* GetMatchedScoresForDocument(
64       int query_vector_index,
65       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
66       DocumentId doc_id) const {
67     const EmbeddingQueryScoreMap* score_map =
68         GetScoreMap(query_vector_index, metric_type);
69     if (score_map == nullptr) {
70       return nullptr;
71     }
72     // Check if the doc_id exists in the score_map
73     auto scores_it = score_map->find(doc_id);
74     if (scores_it == score_map->end()) {
75       return nullptr;
76     }
77     return &scores_it->second;
78   }
79 };
80 
81 }  // namespace lib
82 }  // namespace icing
83 
84 #endif  // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
85