xref: /aosp_15_r20/external/icing/icing/scoring/advanced_scoring/advanced-scorer.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
16 #define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25 
26 #include "icing/text_classifier/lib3/utils/base/status.h"
27 #include "icing/text_classifier/lib3/utils/base/statusor.h"
28 #include "icing/feature-flags.h"
29 #include "icing/index/embed/embedding-query-results.h"
30 #include "icing/index/hit/doc-hit-info.h"
31 #include "icing/index/iterator/doc-hit-info-iterator.h"
32 #include "icing/join/join-children-fetcher.h"
33 #include "icing/proto/scoring.pb.h"
34 #include "icing/schema/schema-store.h"
35 #include "icing/scoring/advanced_scoring/score-expression.h"
36 #include "icing/scoring/bm25f-calculator.h"
37 #include "icing/scoring/scorer.h"
38 #include "icing/scoring/section-weights.h"
39 #include "icing/store/document-store.h"
40 #include "icing/util/logging.h"
41 
42 namespace icing {
43 namespace lib {
44 
45 class AdvancedScorer : public Scorer {
46  public:
47   // Returns:
48   //   A AdvancedScorer instance on success
49   //   FAILED_PRECONDITION on any null pointer input
50   //   INVALID_ARGUMENT if fails to create an instance
51   static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create(
52       const ScoringSpecProto& scoring_spec, double default_score,
53       SearchSpecProto::EmbeddingQueryMetricType::Code
54           default_semantic_metric_type,
55       const DocumentStore* document_store, const SchemaStore* schema_store,
56       int64_t current_time_ms, const JoinChildrenFetcher* join_children_fetcher,
57       const EmbeddingQueryResults* embedding_query_results,
58       const FeatureFlags* feature_flags);
59 
GetScore(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)60   double GetScore(const DocHitInfo& hit_info,
61                   const DocHitInfoIterator* query_it) override {
62     return GetScoreFromExpression(score_expression_.get(), hit_info, query_it);
63   }
64 
GetAdditionalScores(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)65   std::vector<double> GetAdditionalScores(
66       const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override {
67     std::vector<double> additional_scores;
68     additional_scores.reserve(additional_score_expressions_.size());
69     for (const auto& additional_score_expression :
70          additional_score_expressions_) {
71       additional_scores.push_back(GetScoreFromExpression(
72           additional_score_expression.get(), hit_info, query_it));
73     }
74     return additional_scores;
75   }
76 
PrepareToScore(std::unordered_map<std::string,std::unique_ptr<DocHitInfoIterator>> * query_term_iterators)77   void PrepareToScore(
78       std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>*
79           query_term_iterators) override {
80     if (query_term_iterators == nullptr || query_term_iterators->empty()) {
81       return;
82     }
83     bm25f_calculator_->PrepareToScore(query_term_iterators);
84   }
85 
is_constant()86   bool is_constant() const { return score_expression_->is_constant(); }
87 
88  private:
AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression,std::vector<std::unique_ptr<ScoreExpression>> additional_score_expressions,std::unique_ptr<SectionWeights> section_weights,std::unique_ptr<Bm25fCalculator> bm25f_calculator,std::unique_ptr<SchemaTypeAliasMap> alias_schema_type_map,std::unique_ptr<std::unordered_set<ScoringFeatureType>> scoring_feature_types_enabled,double default_score)89   explicit AdvancedScorer(
90       std::unique_ptr<ScoreExpression> score_expression,
91       std::vector<std::unique_ptr<ScoreExpression>>
92           additional_score_expressions,
93       std::unique_ptr<SectionWeights> section_weights,
94       std::unique_ptr<Bm25fCalculator> bm25f_calculator,
95       std::unique_ptr<SchemaTypeAliasMap> alias_schema_type_map,
96       std::unique_ptr<std::unordered_set<ScoringFeatureType>>
97           scoring_feature_types_enabled,
98       double default_score)
99       : score_expression_(std::move(score_expression)),
100         additional_score_expressions_(std::move(additional_score_expressions)),
101         section_weights_(std::move(section_weights)),
102         bm25f_calculator_(std::move(bm25f_calculator)),
103         alias_schema_type_map_(std::move(alias_schema_type_map)),
104         scoring_feature_types_enabled_(
105             std::move(scoring_feature_types_enabled)),
106         default_score_(default_score) {
107     if (is_constant()) {
108       ICING_LOG(WARNING)
109           << "The advanced scoring expression will evaluate to a constant.";
110     }
111   }
112 
GetScoreFromExpression(ScoreExpression * expression,const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)113   double GetScoreFromExpression(ScoreExpression* expression,
114                                 const DocHitInfo& hit_info,
115                                 const DocHitInfoIterator* query_it) {
116     libtextclassifier3::StatusOr<double> result =
117         expression->EvaluateDouble(hit_info, query_it);
118     if (!result.ok()) {
119       ICING_LOG(ERROR) << "Got an error when scoring a document:\n"
120                        << result.status().error_message();
121       return default_score_;
122     }
123     return std::move(result).ValueOrDie();
124   }
125 
126   std::unique_ptr<ScoreExpression> score_expression_;
127   // Additional score expressions that are used to return extra helpful scores
128   // for clients.
129   std::vector<std::unique_ptr<ScoreExpression>> additional_score_expressions_;
130   std::unique_ptr<SectionWeights> section_weights_;
131   std::unique_ptr<Bm25fCalculator> bm25f_calculator_;
132   std::unique_ptr<SchemaTypeAliasMap> alias_schema_type_map_;
133   std::unique_ptr<std::unordered_set<ScoringFeatureType>>
134       scoring_feature_types_enabled_;
135   double default_score_;
136 };
137 
138 }  // namespace lib
139 }  // namespace icing
140 
141 #endif  // ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
142