xref: /aosp_15_r20/external/icing/icing/scoring/advanced_scoring/score-expression.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
// Copyright (C) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
16 #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <string_view>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "icing/text_classifier/lib3/utils/base/statusor.h"
28 #include "icing/absl_ports/canonical_errors.h"
29 #include "icing/index/embed/embedding-query-results.h"
30 #include "icing/index/hit/doc-hit-info.h"
31 #include "icing/index/iterator/doc-hit-info-iterator.h"
32 #include "icing/join/join-children-fetcher.h"
33 #include "icing/schema/schema-store.h"
34 #include "icing/scoring/bm25f-calculator.h"
35 #include "icing/scoring/section-weights.h"
36 #include "icing/store/document-filter-data.h"
37 #include "icing/store/document-id.h"
38 #include "icing/store/document-store.h"
39 #include "icing/util/status-macros.h"
40 
41 namespace icing {
42 namespace lib {
43 
// The static type that a ScoreExpression evaluates to, as reported by
// ScoreExpression::type().
enum class ScoreExpressionType {
  // A single double; evaluated through ScoreExpression::EvaluateDouble().
  kDouble,
  // A list of doubles; evaluated through ScoreExpression::EvaluateList().
  kDoubleList,
  kDocument,  // Only "this" is considered as document type.
  // TODO(b/326656531): Instead of creating a vector index type, consider
  // changing it to vector type so that the data is the vector directly.
  kVectorIndex,
  // A string; evaluated through ScoreExpression::EvaluateString().
  kString,
};
53 
// Maps an alias schema type name to the set of Icing schema type names that
// the alias stands for.
using SchemaTypeAliasMap =
    std::unordered_map<std::string, std::unordered_set<std::string>>;
57 
58 class ScoreExpression {
59  public:
60   virtual ~ScoreExpression() = default;
61 
62   // Evaluate the score expression to double with the current document.
63   //
64   // RETURNS:
65   //   - The evaluated result as a double on success.
66   //   - INVALID_ARGUMENT if a non-finite value is reached while evaluating the
67   //                      expression.
68   //   - INTERNAL if there are inconsistencies.
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)69   virtual libtextclassifier3::StatusOr<double> EvaluateDouble(
70       const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
71     if (type() == ScoreExpressionType::kDouble) {
72       return absl_ports::UnimplementedError(
73           "All ScoreExpressions of type double must provide their own "
74           "implementation of EvaluateDouble!");
75     }
76     return absl_ports::InternalError(
77         "Runtime type error: the expression should never be evaluated to a "
78         "double. There must be inconsistencies in the static type checking.");
79   }
80 
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)81   virtual libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
82       const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
83     if (type() == ScoreExpressionType::kDoubleList) {
84       return absl_ports::UnimplementedError(
85           "All ScoreExpressions of type double List must provide their own "
86           "implementation of EvaluateList!");
87     }
88     return absl_ports::InternalError(
89         "Runtime type error: the expression should never be evaluated to a "
90         "double list. There must be inconsistencies in the static type "
91         "checking.");
92   }
93 
EvaluateString()94   virtual libtextclassifier3::StatusOr<std::string_view> EvaluateString()
95       const {
96     if (type() == ScoreExpressionType::kString) {
97       return absl_ports::UnimplementedError(
98           "All ScoreExpressions of type string must provide their own "
99           "implementation of EvaluateString!");
100     }
101     return absl_ports::InternalError(
102         "Runtime type error: the expression should never be evaluated to a "
103         "string. There must be inconsistencies in the static type checking.");
104   }
105 
106   // Indicate the type to which the current expression will be evaluated.
107   virtual ScoreExpressionType type() const = 0;
108 
109   // Indicate whether the current expression is a constant.
110   // Returns true if and only if the object is of ConstantScoreExpression or
111   // StringExpression type.
is_constant()112   virtual bool is_constant() const { return false; }
113 };
114 
115 class ThisExpression : public ScoreExpression {
116  public:
Create()117   static std::unique_ptr<ThisExpression> Create() {
118     return std::unique_ptr<ThisExpression>(new ThisExpression());
119   }
120 
type()121   ScoreExpressionType type() const override {
122     return ScoreExpressionType::kDocument;
123   }
124 
125  private:
126   ThisExpression() = default;
127 };
128 
129 class ConstantScoreExpression : public ScoreExpression {
130  public:
131   static std::unique_ptr<ConstantScoreExpression> Create(
132       double c, ScoreExpressionType type = ScoreExpressionType::kDouble) {
133     return std::unique_ptr<ConstantScoreExpression>(
134         new ConstantScoreExpression(c, type));
135   }
136 
EvaluateDouble(const DocHitInfo &,const DocHitInfoIterator *)137   libtextclassifier3::StatusOr<double> EvaluateDouble(
138       const DocHitInfo&, const DocHitInfoIterator*) const override {
139     return c_;
140   }
141 
type()142   ScoreExpressionType type() const override { return type_; }
143 
is_constant()144   bool is_constant() const override { return true; }
145 
146  private:
ConstantScoreExpression(double c,ScoreExpressionType type)147   explicit ConstantScoreExpression(double c, ScoreExpressionType type)
148       : c_(c), type_(type) {}
149 
150   double c_;
151   ScoreExpressionType type_;
152 };
153 
154 class StringExpression : public ScoreExpression {
155  public:
Create(std::string str)156   static std::unique_ptr<StringExpression> Create(std::string str) {
157     return std::unique_ptr<StringExpression>(
158         new StringExpression(std::move(str)));
159   }
160 
EvaluateString()161   libtextclassifier3::StatusOr<std::string_view> EvaluateString()
162       const override {
163     return str_;
164   }
165 
type()166   ScoreExpressionType type() const override {
167     return ScoreExpressionType::kString;
168   }
169 
is_constant()170   bool is_constant() const override { return true; }
171 
172  private:
StringExpression(std::string str)173   explicit StringExpression(std::string str) : str_(std::move(str)) {}
174   std::string str_;
175 };
176 
177 class OperatorScoreExpression : public ScoreExpression {
178  public:
179   enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv };
180 
181   // RETURNS:
182   //   - An OperatorScoreExpression instance on success if not simplifiable.
183   //   - A ConstantScoreExpression instance on success if simplifiable.
184   //   - FAILED_PRECONDITION on any null pointer in children.
185   //   - INVALID_ARGUMENT on type errors.
186   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
187       OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children);
188 
189   libtextclassifier3::StatusOr<double> EvaluateDouble(
190       const DocHitInfo& hit_info,
191       const DocHitInfoIterator* query_it) const override;
192 
type()193   ScoreExpressionType type() const override {
194     return ScoreExpressionType::kDouble;
195   }
196 
197  private:
OperatorScoreExpression(OperatorType op,std::vector<std::unique_ptr<ScoreExpression>> children)198   explicit OperatorScoreExpression(
199       OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children)
200       : op_(op), children_(std::move(children)) {}
201 
202   OperatorType op_;
203   std::vector<std::unique_ptr<ScoreExpression>> children_;
204 };
205 
206 class MathFunctionScoreExpression : public ScoreExpression {
207  public:
208   enum class FunctionType {
209     kLog,
210     kPow,
211     kMax,
212     kMin,
213     kLen,
214     kSum,
215     kAvg,
216     kSqrt,
217     kAbs,
218     kSin,
219     kCos,
220     kTan,
221     kMaxOrDefault,
222     kMinOrDefault,
223   };
224 
225   static const std::unordered_map<std::string, FunctionType> kFunctionNames;
226 
227   static const std::unordered_set<FunctionType> kVariableArgumentsFunctions;
228 
229   static const std::unordered_set<FunctionType> kListArgumentFunctions;
230 
231   // RETURNS:
232   //   - A MathFunctionScoreExpression instance on success if not simplifiable.
233   //   - A ConstantScoreExpression instance on success if simplifiable.
234   //   - FAILED_PRECONDITION on any null pointer in args.
235   //   - INVALID_ARGUMENT on type errors.
236   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
237       FunctionType function_type,
238       std::vector<std::unique_ptr<ScoreExpression>> args);
239 
240   libtextclassifier3::StatusOr<double> EvaluateDouble(
241       const DocHitInfo& hit_info,
242       const DocHitInfoIterator* query_it) const override;
243 
type()244   ScoreExpressionType type() const override {
245     return ScoreExpressionType::kDouble;
246   }
247 
248  private:
MathFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)249   explicit MathFunctionScoreExpression(
250       FunctionType function_type,
251       std::vector<std::unique_ptr<ScoreExpression>> args)
252       : function_type_(function_type), args_(std::move(args)) {}
253 
254   FunctionType function_type_;
255   std::vector<std::unique_ptr<ScoreExpression>> args_;
256 };
257 
258 class ListOperationFunctionScoreExpression : public ScoreExpression {
259  public:
260   enum class FunctionType { kFilterByRange };
261 
262   static const std::unordered_map<std::string, FunctionType> kFunctionNames;
263 
264   // RETURNS:
265   //   - A ListOperationFunctionScoreExpression instance on success.
266   //   - FAILED_PRECONDITION on any null pointer in args.
267   //   - INVALID_ARGUMENT on type errors.
268   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
269       FunctionType function_type,
270       std::vector<std::unique_ptr<ScoreExpression>> args);
271 
272   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
273       const DocHitInfo& hit_info,
274       const DocHitInfoIterator* query_it) const override;
275 
type()276   ScoreExpressionType type() const override {
277     return ScoreExpressionType::kDoubleList;
278   }
279 
280  private:
ListOperationFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)281   explicit ListOperationFunctionScoreExpression(
282       FunctionType function_type,
283       std::vector<std::unique_ptr<ScoreExpression>> args)
284       : function_type_(function_type), args_(std::move(args)) {}
285 
286   FunctionType function_type_;
287   std::vector<std::unique_ptr<ScoreExpression>> args_;
288 };
289 
290 class DocumentFunctionScoreExpression : public ScoreExpression {
291  public:
292   enum class FunctionType {
293     kDocumentScore,
294     kCreationTimestamp,
295     kUsageCount,
296     kUsageLastUsedTimestamp,
297   };
298 
299   static const std::unordered_map<std::string, FunctionType> kFunctionNames;
300 
301   // RETURNS:
302   //   - A DocumentFunctionScoreExpression instance on success.
303   //   - FAILED_PRECONDITION on any null pointer in args.
304   //   - INVALID_ARGUMENT on type errors.
305   static libtextclassifier3::StatusOr<
306       std::unique_ptr<DocumentFunctionScoreExpression>>
307   Create(FunctionType function_type,
308          std::vector<std::unique_ptr<ScoreExpression>> args,
309          const DocumentStore* document_store, double default_score,
310          int64_t current_time_ms);
311 
312   libtextclassifier3::StatusOr<double> EvaluateDouble(
313       const DocHitInfo& hit_info,
314       const DocHitInfoIterator* query_it) const override;
315 
type()316   ScoreExpressionType type() const override {
317     return ScoreExpressionType::kDouble;
318   }
319 
320  private:
DocumentFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,double default_score,int64_t current_time_ms)321   explicit DocumentFunctionScoreExpression(
322       FunctionType function_type,
323       std::vector<std::unique_ptr<ScoreExpression>> args,
324       const DocumentStore* document_store, double default_score,
325       int64_t current_time_ms)
326       : args_(std::move(args)),
327         document_store_(*document_store),
328         default_score_(default_score),
329         function_type_(function_type),
330         current_time_ms_(current_time_ms) {}
331 
332   std::vector<std::unique_ptr<ScoreExpression>> args_;
333   const DocumentStore& document_store_;
334   double default_score_;
335   FunctionType function_type_;
336   int64_t current_time_ms_;
337 };
338 
339 class RelevanceScoreFunctionScoreExpression : public ScoreExpression {
340  public:
341   static constexpr std::string_view kFunctionName = "relevanceScore";
342 
343   // RETURNS:
344   //   - A RelevanceScoreFunctionScoreExpression instance on success.
345   //   - FAILED_PRECONDITION on any null pointer in args.
346   //   - INVALID_ARGUMENT on type errors.
347   static libtextclassifier3::StatusOr<
348       std::unique_ptr<RelevanceScoreFunctionScoreExpression>>
349   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
350          Bm25fCalculator* bm25f_calculator, double default_score);
351 
352   libtextclassifier3::StatusOr<double> EvaluateDouble(
353       const DocHitInfo& hit_info,
354       const DocHitInfoIterator* query_it) const override;
355 
type()356   ScoreExpressionType type() const override {
357     return ScoreExpressionType::kDouble;
358   }
359 
360  private:
RelevanceScoreFunctionScoreExpression(Bm25fCalculator * bm25f_calculator,double default_score)361   explicit RelevanceScoreFunctionScoreExpression(
362       Bm25fCalculator* bm25f_calculator, double default_score)
363       : bm25f_calculator_(*bm25f_calculator), default_score_(default_score) {}
364 
365   Bm25fCalculator& bm25f_calculator_;
366   double default_score_;
367 };
368 
369 class ChildrenRankingSignalsFunctionScoreExpression : public ScoreExpression {
370  public:
371   static constexpr std::string_view kFunctionName = "childrenRankingSignals";
372 
373   // RETURNS:
374   //   - A ChildrenRankingSignalsFunctionScoreExpression instance on success.
375   //   - FAILED_PRECONDITION on any null pointer in children.
376   //   - INVALID_ARGUMENT on type errors.
377   static libtextclassifier3::StatusOr<
378       std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>>
379   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
380          const DocumentStore& document_store,
381          const JoinChildrenFetcher* join_children_fetcher,
382          int64_t current_time_ms);
383 
384   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
385       const DocHitInfo& hit_info,
386       const DocHitInfoIterator* query_it) const override;
387 
type()388   ScoreExpressionType type() const override {
389     return ScoreExpressionType::kDoubleList;
390   }
391 
392  private:
ChildrenRankingSignalsFunctionScoreExpression(const DocumentStore & document_store,const JoinChildrenFetcher & join_children_fetcher,int64_t current_time_ms)393   explicit ChildrenRankingSignalsFunctionScoreExpression(
394       const DocumentStore& document_store,
395       const JoinChildrenFetcher& join_children_fetcher, int64_t current_time_ms)
396       : document_store_(document_store),
397         join_children_fetcher_(join_children_fetcher),
398         current_time_ms_(current_time_ms) {}
399 
400   const DocumentStore& document_store_;               // Does not own.
401   const JoinChildrenFetcher& join_children_fetcher_;  // Does not own.
402   int64_t current_time_ms_;
403 };
404 
405 class PropertyWeightsFunctionScoreExpression : public ScoreExpression {
406  public:
407   static constexpr std::string_view kFunctionName = "propertyWeights";
408 
409   // RETURNS:
410   //   - A PropertyWeightsFunctionScoreExpression instance on success.
411   //   - FAILED_PRECONDITION on any null pointer in children.
412   //   - INVALID_ARGUMENT on type errors.
413   static libtextclassifier3::StatusOr<
414       std::unique_ptr<PropertyWeightsFunctionScoreExpression>>
415   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
416          const DocumentStore* document_store,
417          const SectionWeights* section_weights, int64_t current_time_ms);
418 
419   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
420       const DocHitInfo& hit_info, const DocHitInfoIterator*) const override;
421 
type()422   ScoreExpressionType type() const override {
423     return ScoreExpressionType::kDoubleList;
424   }
425 
426  private:
PropertyWeightsFunctionScoreExpression(const DocumentStore * document_store,const SectionWeights * section_weights,int64_t current_time_ms)427   explicit PropertyWeightsFunctionScoreExpression(
428       const DocumentStore* document_store,
429       const SectionWeights* section_weights, int64_t current_time_ms)
430       : document_store_(*document_store),
431         section_weights_(*section_weights),
432         current_time_ms_(current_time_ms) {}
433   const DocumentStore& document_store_;
434   const SectionWeights& section_weights_;
435   int64_t current_time_ms_;
436 };
437 
438 class GetEmbeddingParameterFunctionScoreExpression : public ScoreExpression {
439  public:
440   static constexpr std::string_view kFunctionName = "getEmbeddingParameter";
441 
442   // RETURNS:
443   //   - A GetEmbeddingParameterFunctionScoreExpression instance on success if
444   //     not simplifiable.
445   //   - A ConstantScoreExpression instance on success if simplifiable.
446   //   - FAILED_PRECONDITION on any null pointer in children.
447   //   - INVALID_ARGUMENT on type errors.
448   static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
449       std::vector<std::unique_ptr<ScoreExpression>> args);
450 
451   libtextclassifier3::StatusOr<double> EvaluateDouble(
452       const DocHitInfo& hit_info,
453       const DocHitInfoIterator* query_it) const override;
454 
type()455   ScoreExpressionType type() const override {
456     return ScoreExpressionType::kVectorIndex;
457   }
458 
459  private:
GetEmbeddingParameterFunctionScoreExpression(std::unique_ptr<ScoreExpression> arg)460   explicit GetEmbeddingParameterFunctionScoreExpression(
461       std::unique_ptr<ScoreExpression> arg)
462       : arg_(std::move(arg)) {}
463   std::unique_ptr<ScoreExpression> arg_;
464 };
465 
466 class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression {
467  public:
468   static constexpr std::string_view kFunctionName = "matchedSemanticScores";
469 
470   // RETURNS:
471   //   - A MatchedSemanticScoresFunctionScoreExpression instance on success.
472   //   - FAILED_PRECONDITION on any null pointer in children.
473   //   - INVALID_ARGUMENT on type errors.
474   static libtextclassifier3::StatusOr<
475       std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
476   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
477          SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
478          const EmbeddingQueryResults* embedding_query_results);
479 
480   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
481       const DocHitInfo& hit_info,
482       const DocHitInfoIterator* query_it) const override;
483 
type()484   ScoreExpressionType type() const override {
485     return ScoreExpressionType::kDoubleList;
486   }
487 
488  private:
MatchedSemanticScoresFunctionScoreExpression(std::vector<std::unique_ptr<ScoreExpression>> args,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,const EmbeddingQueryResults & embedding_query_results)489   explicit MatchedSemanticScoresFunctionScoreExpression(
490       std::vector<std::unique_ptr<ScoreExpression>> args,
491       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
492       const EmbeddingQueryResults& embedding_query_results)
493       : args_(std::move(args)),
494         metric_type_(metric_type),
495         embedding_query_results_(embedding_query_results) {}
496 
497   std::vector<std::unique_ptr<ScoreExpression>> args_;
498   const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
499   const EmbeddingQueryResults& embedding_query_results_;
500 };
501 
502 class GetScorablePropertyFunctionScoreExpression : public ScoreExpression {
503  public:
504   static constexpr std::string_view kFunctionName = "getScorableProperty";
505 
506   // Returns:
507   //   - FAILED_PRECONDITION on any null pointer in children.
508   //   - INVALID_ARGUMENT on
509   //     - |args| type errors.
510   //     - alias_schema_type in the scoring expression does not match to any
511   //       schema type in the |schema_type_alias_map|.
512   //     - any matched schema type not having the specified property_path as a
513   //       scorable property.
514   static libtextclassifier3::StatusOr<
515       std::unique_ptr<GetScorablePropertyFunctionScoreExpression>>
516   Create(std::vector<std::unique_ptr<ScoreExpression>> args,
517          const DocumentStore* document_store, const SchemaStore* schema_store,
518          const SchemaTypeAliasMap& schema_type_alias_map,
519          int64_t current_time_ms);
520 
type()521   ScoreExpressionType type() const override {
522     return ScoreExpressionType::kDoubleList;
523   }
524 
525   libtextclassifier3::StatusOr<std::vector<double>> EvaluateList(
526       const DocHitInfo& hit_info,
527       const DocHitInfoIterator* query_it) const override;
528 
529  private:
530   explicit GetScorablePropertyFunctionScoreExpression(
531       const DocumentStore* document_store, const SchemaStore* schema_store,
532       int64_t current_time_ms,
533       std::unordered_set<SchemaTypeId>&& schema_type_ids,
534       std::string_view property_path);
535 
536   // Returns a set of schema type ids that are matched to the
537   // |alias_schema_type| in the scoring expression, based on the
538   // |schema_type_alias_map|.
539   //
540   // For each of the schema type in the returned set, this function also
541   // validates that:
542   //   - The schema type is valid in the schema store.
543   //   - The |property_path| is defined as scorable under the schema type.
544   static libtextclassifier3::StatusOr<std::unordered_set<SchemaTypeId>>
545   GetAndValidateSchemaTypeIds(std::string_view alias_schema_type,
546                               std::string_view property_path,
547                               const SchemaTypeAliasMap& schema_type_alias_map,
548                               const SchemaStore& schema_store);
549 
550   const DocumentStore& document_store_;
551   const SchemaStore& schema_store_;
552   int64_t current_time_ms_;
553   // A doc hit is evaluated by this function only if its schema type id is in
554   // this set.
555   std::unordered_set<SchemaTypeId> schema_type_ids_;
556   std::string property_path_;
557 };
558 
559 }  // namespace lib
560 }  // namespace icing
561 
562 #endif  // ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
563