1 // Copyright (C) 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ 16 #define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ 17 18 #include <cstdint> 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <unordered_set> 23 #include <utility> 24 #include <vector> 25 26 #include "icing/text_classifier/lib3/utils/base/status.h" 27 #include "icing/text_classifier/lib3/utils/base/statusor.h" 28 #include "icing/feature-flags.h" 29 #include "icing/index/embed/embedding-query-results.h" 30 #include "icing/index/hit/doc-hit-info.h" 31 #include "icing/index/iterator/doc-hit-info-iterator.h" 32 #include "icing/join/join-children-fetcher.h" 33 #include "icing/proto/scoring.pb.h" 34 #include "icing/schema/schema-store.h" 35 #include "icing/scoring/advanced_scoring/score-expression.h" 36 #include "icing/scoring/bm25f-calculator.h" 37 #include "icing/scoring/scorer.h" 38 #include "icing/scoring/section-weights.h" 39 #include "icing/store/document-store.h" 40 #include "icing/util/logging.h" 41 42 namespace icing { 43 namespace lib { 44 45 class AdvancedScorer : public Scorer { 46 public: 47 // Returns: 48 // A AdvancedScorer instance on success 49 // FAILED_PRECONDITION on any null pointer input 50 // INVALID_ARGUMENT if fails to create an instance 51 static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create( 52 const ScoringSpecProto& scoring_spec, double default_score, 53 SearchSpecProto::EmbeddingQueryMetricType::Code 54 default_semantic_metric_type, 55 const DocumentStore* document_store, const SchemaStore* schema_store, 56 int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher, 57 const EmbeddingQueryResults* embedding_query_results, 58 const FeatureFlags* feature_flags); 59 GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)60 double GetScore(const DocHitInfo& hit_info, 61 const DocHitInfoIterator* query_it) override { 62 return GetScoreFromExpression(score_expression_.get(), hit_info, query_it); 63 } 64 GetAdditionalScores(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)65 std::vector<double> GetAdditionalScores( 66 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override { 67 std::vector<double> additional_scores; 68 additional_scores.reserve(additional_score_expressions_.size()); 69 for (const auto& additional_score_expression : 70 additional_score_expressions_) { 71 additional_scores.push_back(GetScoreFromExpression( 72 additional_score_expression.get(), hit_info, query_it)); 73 } 74 return additional_scores; 75 } 76 PrepareToScore(std::unordered_map<std::string,std::unique_ptr<DocHitInfoIterator>> * query_term_iterators)77 void PrepareToScore( 78 std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>* 79 query_term_iterators) override { 80 if (query_term_iterators == nullptr || query_term_iterators->empty()) { 81 return; 82 } 83 bm25f_calculator_->PrepareToScore(query_term_iterators); 84 } 85 is_constant()86 bool is_constant() const { return score_expression_->is_constant(); } 87 88 private: AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression,std::vector<std::unique_ptr<ScoreExpression>> additional_score_expressions,std::unique_ptr<SectionWeights> section_weights,std::unique_ptr<Bm25fCalculator> bm25f_calculator,std::unique_ptr<SchemaTypeAliasMap> alias_schema_type_map,std::unique_ptr<std::unordered_set<ScoringFeatureType>> scoring_feature_types_enabled,double default_score)89 explicit AdvancedScorer( 90 std::unique_ptr<ScoreExpression> score_expression, 91 std::vector<std::unique_ptr<ScoreExpression>> 92 additional_score_expressions, 93 std::unique_ptr<SectionWeights> section_weights, 94 std::unique_ptr<Bm25fCalculator> bm25f_calculator, 95 std::unique_ptr<SchemaTypeAliasMap> alias_schema_type_map, 96 std::unique_ptr<std::unordered_set<ScoringFeatureType>> 97 scoring_feature_types_enabled, 98 double default_score) 99 : score_expression_(std::move(score_expression)), 100 additional_score_expressions_(std::move(additional_score_expressions)), 101 section_weights_(std::move(section_weights)), 102 bm25f_calculator_(std::move(bm25f_calculator)), 103 alias_schema_type_map_(std::move(alias_schema_type_map)), 104 scoring_feature_types_enabled_( 105 std::move(scoring_feature_types_enabled)), 106 default_score_(default_score) { 107 if (is_constant()) { 108 ICING_LOG(WARNING) 109 << "The advanced scoring expression will evaluate to a constant."; 110 } 111 } 112 GetScoreFromExpression(ScoreExpression * expression,const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)113 double GetScoreFromExpression(ScoreExpression* expression, 114 const DocHitInfo& hit_info, 115 const DocHitInfoIterator* query_it) { 116 libtextclassifier3::StatusOr<double> result = 117 expression->EvaluateDouble(hit_info, query_it); 118 if (!result.ok()) { 119 ICING_LOG(ERROR) << "Got an error when scoring a document:\n" 120 << result.status().error_message(); 121 return default_score_; 122 } 123 return std::move(result).ValueOrDie(); 124 } 125 126 std::unique_ptr<ScoreExpression> score_expression_; 127 // Additional score expressions that are used to return extra helpful scores 128 // for clients. 129 std::vector<std::unique_ptr<ScoreExpression>> additional_score_expressions_; 130 std::unique_ptr<SectionWeights> section_weights_; 131 std::unique_ptr<Bm25fCalculator> bm25f_calculator_; 132 std::unique_ptr<SchemaTypeAliasMap> alias_schema_type_map_; 133 std::unique_ptr<std::unordered_set<ScoringFeatureType>> 134 scoring_feature_types_enabled_; 135 double default_score_; 136 }; 137 138 } // namespace lib 139 } // namespace icing 140 141 #endif // ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ 142