1 // Copyright (C) 2024 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ 16 #define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ 17 18 #include <unordered_map> 19 #include <vector> 20 21 #include "icing/proto/search.pb.h" 22 #include "icing/store/document-id.h" 23 24 namespace icing { 25 namespace lib { 26 27 // A class to store results generated from embedding queries. 28 struct EmbeddingQueryResults { 29 // Maps from DocumentId to the list of matched embedding scores for that 30 // document, which will be used in the advanced scoring language to 31 // determine the results for the "this.matchedSemanticScores(...)" function. 32 using EmbeddingQueryScoreMap = 33 std::unordered_map<DocumentId, std::vector<double>>; 34 35 // Maps from (query_vector_index, metric_type) to EmbeddingQueryScoreMap. 36 std::unordered_map< 37 int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code, 38 EmbeddingQueryScoreMap>> 39 result_scores; 40 41 // Get the score map for the given query_vector_index and metric_type. Returns 42 // nullptr if (query_vector_index, metric_type) does not exist in the 43 // result_scores map. GetScoreMapEmbeddingQueryResults44 const EmbeddingQueryScoreMap* GetScoreMap( 45 int query_vector_index, 46 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) const { 47 // Check if a mapping exists for the query_vector_index 48 auto outer_it = result_scores.find(query_vector_index); 49 if (outer_it == result_scores.end()) { 50 return nullptr; 51 } 52 // Check if a mapping exists for the metric_type 53 auto inner_it = outer_it->second.find(metric_type); 54 if (inner_it == outer_it->second.end()) { 55 return nullptr; 56 } 57 return &inner_it->second; 58 } 59 60 // Returns the matched scores for the given query_vector_index, metric_type, 61 // and doc_id. Returns nullptr if (query_vector_index, metric_type, doc_id) 62 // does not exist in the result_scores map. GetMatchedScoresForDocumentEmbeddingQueryResults63 const std::vector<double>* GetMatchedScoresForDocument( 64 int query_vector_index, 65 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, 66 DocumentId doc_id) const { 67 const EmbeddingQueryScoreMap* score_map = 68 GetScoreMap(query_vector_index, metric_type); 69 if (score_map == nullptr) { 70 return nullptr; 71 } 72 // Check if the doc_id exists in the score_map 73 auto scores_it = score_map->find(doc_id); 74 if (scores_it == score_map->end()) { 75 return nullptr; 76 } 77 return &scores_it->second; 78 } 79 }; 80 81 } // namespace lib 82 } // namespace icing 83 84 #endif // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_ 85