xref: /aosp_15_r20/external/icing/icing/scoring/advanced_scoring/score-expression.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/scoring/advanced_scoring/score-expression.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstdint>
20 #include <cstdlib>
21 #include <limits>
22 #include <memory>
23 #include <numeric>
24 #include <optional>
25 #include <string>
26 #include <string_view>
27 #include <unordered_map>
28 #include <unordered_set>
29 #include <utility>
30 #include <vector>
31 
32 #include "icing/text_classifier/lib3/utils/base/status.h"
33 #include "icing/text_classifier/lib3/utils/base/statusor.h"
34 #include "icing/absl_ports/canonical_errors.h"
35 #include "icing/absl_ports/str_cat.h"
36 #include "icing/index/embed/embedding-query-results.h"
37 #include "icing/index/hit/doc-hit-info.h"
38 #include "icing/index/iterator/doc-hit-info-iterator.h"
39 #include "icing/join/join-children-fetcher.h"
40 #include "icing/legacy/core/icing-string-util.h"
41 #include "icing/proto/internal/scorable_property_set.pb.h"
42 #include "icing/schema/schema-store.h"
43 #include "icing/schema/section.h"
44 #include "icing/scoring/bm25f-calculator.h"
45 #include "icing/scoring/scored-document-hit.h"
46 #include "icing/scoring/section-weights.h"
47 #include "icing/store/document-associated-score-data.h"
48 #include "icing/store/document-filter-data.h"
49 #include "icing/store/document-id.h"
50 #include "icing/store/document-store.h"
51 #include "icing/util/embedding-util.h"
52 #include "icing/util/logging.h"
53 #include "icing/util/scorable_property_set.h"
54 #include "icing/util/status-macros.h"
55 
56 namespace icing {
57 namespace lib {
58 
59 namespace {
60 
CheckChildrenNotNull(const std::vector<std::unique_ptr<ScoreExpression>> & children)61 libtextclassifier3::Status CheckChildrenNotNull(
62     const std::vector<std::unique_ptr<ScoreExpression>>& children) {
63   for (const auto& child : children) {
64     ICING_RETURN_ERROR_IF_NULL(child);
65   }
66   return libtextclassifier3::Status::OK;
67 }
68 
GetSchemaTypeId(DocumentId document_id,const DocumentStore & document_store,int64_t current_time_ms)69 SchemaTypeId GetSchemaTypeId(DocumentId document_id,
70                              const DocumentStore& document_store,
71                              int64_t current_time_ms) {
72   auto filter_data_optional =
73       document_store.GetAliveDocumentFilterData(document_id, current_time_ms);
74   if (!filter_data_optional) {
75     // This should never happen. The only failure case for
76     // GetAliveDocumentFilterData is if the document_id is outside of the range
77     // of allocated document_ids, which shouldn't be possible since we're
78     // getting this document_id from the posting lists.
79     ICING_LOG(WARNING) << "No document filter data for document ["
80                        << document_id << "]";
81     return kInvalidSchemaTypeId;
82   }
83   return filter_data_optional.value().schema_type_id();
84 }
85 
86 }  // namespace
87 
88 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(OperatorType op,std::vector<std::unique_ptr<ScoreExpression>> children)89 OperatorScoreExpression::Create(
90     OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) {
91   if (children.empty()) {
92     return absl_ports::InvalidArgumentError(
93         "OperatorScoreExpression must have at least one argument.");
94   }
95   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(children));
96 
97   bool children_all_constant_double = true;
98   for (const auto& child : children) {
99     if (child->type() != ScoreExpressionType::kDouble) {
100       return absl_ports::InvalidArgumentError(
101           "Operators are only supported for double type.");
102     }
103     if (!child->is_constant()) {
104       children_all_constant_double = false;
105     }
106   }
107   if (op == OperatorType::kNegative) {
108     if (children.size() != 1) {
109       return absl_ports::InvalidArgumentError(
110           "Negative operator must have only 1 argument.");
111     }
112   }
113   std::unique_ptr<ScoreExpression> expression =
114       std::unique_ptr<OperatorScoreExpression>(
115           new OperatorScoreExpression(op, std::move(children)));
116   if (children_all_constant_double) {
117     // Because all of the children are constants, this expression does not
118     // depend on the DocHitInto or query_it that are passed into it.
119     ICING_ASSIGN_OR_RETURN(double constant_value,
120                            expression->EvaluateDouble(DocHitInfo(),
121                                                       /*query_it=*/nullptr));
122     return ConstantScoreExpression::Create(constant_value);
123   }
124   return expression;
125 }
126 
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const127 libtextclassifier3::StatusOr<double> OperatorScoreExpression::EvaluateDouble(
128     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
129   // The Create factory guarantees that an operator will have at least one
130   // child.
131   ICING_ASSIGN_OR_RETURN(double res,
132                          children_.at(0)->EvaluateDouble(hit_info, query_it));
133 
134   if (op_ == OperatorType::kNegative) {
135     return -res;
136   }
137 
138   for (int i = 1; i < children_.size(); ++i) {
139     ICING_ASSIGN_OR_RETURN(double v,
140                            children_.at(i)->EvaluateDouble(hit_info, query_it));
141     switch (op_) {
142       case OperatorType::kPlus:
143         res += v;
144         break;
145       case OperatorType::kMinus:
146         res -= v;
147         break;
148       case OperatorType::kTimes:
149         res *= v;
150         break;
151       case OperatorType::kDiv:
152         res /= v;
153         break;
154       case OperatorType::kNegative:
155         return absl_ports::InternalError("Should never reach here.");
156     }
157     if (!std::isfinite(res)) {
158       return absl_ports::InvalidArgumentError(
159           "Got a non-finite value while evaluating operator score expression.");
160     }
161   }
162   return res;
163 }
164 
165 const std::unordered_map<std::string, MathFunctionScoreExpression::FunctionType>
166     MathFunctionScoreExpression::kFunctionNames = {
167         {"log", FunctionType::kLog},
168         {"pow", FunctionType::kPow},
169         {"max", FunctionType::kMax},
170         {"min", FunctionType::kMin},
171         {"len", FunctionType::kLen},
172         {"sum", FunctionType::kSum},
173         {"avg", FunctionType::kAvg},
174         {"sqrt", FunctionType::kSqrt},
175         {"abs", FunctionType::kAbs},
176         {"sin", FunctionType::kSin},
177         {"cos", FunctionType::kCos},
178         {"tan", FunctionType::kTan},
179         {"maxOrDefault", FunctionType::kMaxOrDefault},
180         {"minOrDefault", FunctionType::kMinOrDefault}};
181 
182 const std::unordered_set<MathFunctionScoreExpression::FunctionType>
183     MathFunctionScoreExpression::kVariableArgumentsFunctions = {
184         FunctionType::kMax, FunctionType::kMin, FunctionType::kLen,
185         FunctionType::kSum, FunctionType::kAvg};
186 
187 const std::unordered_set<MathFunctionScoreExpression::FunctionType>
188     MathFunctionScoreExpression::kListArgumentFunctions = {
189         FunctionType::kMaxOrDefault, FunctionType::kMinOrDefault};
190 
191 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)192 MathFunctionScoreExpression::Create(
193     FunctionType function_type,
194     std::vector<std::unique_ptr<ScoreExpression>> args) {
195   if (args.empty()) {
196     return absl_ports::InvalidArgumentError(
197         "Math functions must have at least one argument.");
198   }
199   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
200 
201   // Return early for functions that support variable length arguments and the
202   // first argument is a list.
203   if (args.size() == 1 && args[0]->type() == ScoreExpressionType::kDoubleList &&
204       kVariableArgumentsFunctions.count(function_type) > 0) {
205     return std::unique_ptr<MathFunctionScoreExpression>(
206         new MathFunctionScoreExpression(function_type, std::move(args)));
207   }
208 
209   bool args_all_constant_double = false;
210   if (kListArgumentFunctions.count(function_type) > 0) {
211     if (args[0]->type() != ScoreExpressionType::kDoubleList) {
212       return absl_ports::InvalidArgumentError(
213           "Got an invalid type for the math function. Should expect a list "
214           "type value in the first argument.");
215     }
216   } else {
217     args_all_constant_double = true;
218     for (const auto& child : args) {
219       if (child->type() != ScoreExpressionType::kDouble) {
220         return absl_ports::InvalidArgumentError(
221             "Got an invalid type for the math function. Should expect a double "
222             "type argument.");
223       }
224       if (!child->is_constant()) {
225         args_all_constant_double = false;
226       }
227     }
228   }
229   switch (function_type) {
230     case FunctionType::kLog:
231       if (args.size() != 1 && args.size() != 2) {
232         return absl_ports::InvalidArgumentError(
233             "log must have 1 or 2 arguments.");
234       }
235       break;
236     case FunctionType::kPow:
237       if (args.size() != 2) {
238         return absl_ports::InvalidArgumentError("pow must have 2 arguments.");
239       }
240       break;
241     case FunctionType::kSqrt:
242       if (args.size() != 1) {
243         return absl_ports::InvalidArgumentError("sqrt must have 1 argument.");
244       }
245       break;
246     case FunctionType::kAbs:
247       if (args.size() != 1) {
248         return absl_ports::InvalidArgumentError("abs must have 1 argument.");
249       }
250       break;
251     case FunctionType::kSin:
252       if (args.size() != 1) {
253         return absl_ports::InvalidArgumentError("sin must have 1 argument.");
254       }
255       break;
256     case FunctionType::kCos:
257       if (args.size() != 1) {
258         return absl_ports::InvalidArgumentError("cos must have 1 argument.");
259       }
260       break;
261     case FunctionType::kTan:
262       if (args.size() != 1) {
263         return absl_ports::InvalidArgumentError("tan must have 1 argument.");
264       }
265       break;
266     case FunctionType::kMaxOrDefault:
267       if (args.size() != 2) {
268         return absl_ports::InvalidArgumentError(
269             "maxOrDefault must have 2 arguments.");
270       }
271       if (args[1]->type() != ScoreExpressionType::kDouble) {
272         return absl_ports::InvalidArgumentError(
273             "maxOrDefault must have a double type as the second argument.");
274       }
275       break;
276     case FunctionType::kMinOrDefault:
277       if (args.size() != 2) {
278         return absl_ports::InvalidArgumentError(
279             "minOrDefault must have 2 arguments.");
280       }
281       if (args[1]->type() != ScoreExpressionType::kDouble) {
282         return absl_ports::InvalidArgumentError(
283             "minOrDefault must have a double type as the second argument.");
284       }
285       break;
286     // Functions that support variable length arguments
287     case FunctionType::kMax:
288       [[fallthrough]];
289     case FunctionType::kMin:
290       [[fallthrough]];
291     case FunctionType::kLen:
292       [[fallthrough]];
293     case FunctionType::kSum:
294       [[fallthrough]];
295     case FunctionType::kAvg:
296       break;
297   }
298   std::unique_ptr<ScoreExpression> expression =
299       std::unique_ptr<MathFunctionScoreExpression>(
300           new MathFunctionScoreExpression(function_type, std::move(args)));
301   if (args_all_constant_double) {
302     // Because all of the arguments are constants, this expression does not
303     // depend on the DocHitInto or query_it that are passed into it.
304     ICING_ASSIGN_OR_RETURN(double constant_value,
305                            expression->EvaluateDouble(DocHitInfo(),
306                                                       /*query_it=*/nullptr));
307     return ConstantScoreExpression::Create(constant_value);
308   }
309   return expression;
310 }
311 
312 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const313 MathFunctionScoreExpression::EvaluateDouble(
314     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
315   std::vector<double> values;
316   int ind = 0;
317   if (args_.at(0)->type() == ScoreExpressionType::kDoubleList) {
318     ICING_ASSIGN_OR_RETURN(values,
319                            args_.at(0)->EvaluateList(hit_info, query_it));
320     ind = 1;
321   }
322   for (; ind < args_.size(); ++ind) {
323     ICING_ASSIGN_OR_RETURN(double v,
324                            args_.at(ind)->EvaluateDouble(hit_info, query_it));
325     values.push_back(v);
326   }
327 
328   double res = 0;
329   switch (function_type_) {
330     case FunctionType::kLog:
331       if (values.size() == 1) {
332         res = log(values[0]);
333       } else {
334         // argument 0 is log base
335         // argument 1 is the value
336         res = log(values[1]) / log(values[0]);
337       }
338       break;
339     case FunctionType::kPow:
340       res = pow(values[0], values[1]);
341       break;
342     case FunctionType::kMax:
343       if (values.empty()) {
344         return absl_ports::InvalidArgumentError(
345             "Got an empty parameter set in max function");
346       }
347       res = *std::max_element(values.begin(), values.end());
348       break;
349     case FunctionType::kMin:
350       if (values.empty()) {
351         return absl_ports::InvalidArgumentError(
352             "Got an empty parameter set in min function");
353       }
354       res = *std::min_element(values.begin(), values.end());
355       break;
356     case FunctionType::kLen:
357       res = values.size();
358       break;
359     case FunctionType::kSum:
360       res = std::reduce(values.begin(), values.end());
361       break;
362     case FunctionType::kAvg:
363       if (values.empty()) {
364         return absl_ports::InvalidArgumentError(
365             "Got an empty parameter set in avg function.");
366       }
367       res = std::reduce(values.begin(), values.end()) / values.size();
368       break;
369     case FunctionType::kSqrt:
370       res = sqrt(values[0]);
371       break;
372     case FunctionType::kAbs:
373       res = abs(values[0]);
374       break;
375     case FunctionType::kSin:
376       res = sin(values[0]);
377       break;
378     case FunctionType::kCos:
379       res = cos(values[0]);
380       break;
381     case FunctionType::kTan:
382       res = tan(values[0]);
383       break;
384     // For the following two functions, the last value is the default value.
385     // If values.size() == 1, then it means the provided list is empty.
386     case FunctionType::kMaxOrDefault:
387       if (values.size() == 1) {
388         res = values[0];
389       } else {
390         res = *std::max_element(values.begin(), values.end() - 1);
391       }
392       break;
393     case FunctionType::kMinOrDefault:
394       if (values.size() == 1) {
395         res = values[0];
396       } else {
397         res = *std::min_element(values.begin(), values.end() - 1);
398       }
399       break;
400   }
401   if (!std::isfinite(res)) {
402     return absl_ports::InvalidArgumentError(
403         "Got a non-finite value while evaluating math function score "
404         "expression.");
405   }
406   return res;
407 }
408 
409 const std::unordered_map<std::string,
410                          ListOperationFunctionScoreExpression::FunctionType>
411     ListOperationFunctionScoreExpression::kFunctionNames = {
412         {"filterByRange", FunctionType::kFilterByRange}};
413 
414 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)415 ListOperationFunctionScoreExpression::Create(
416     FunctionType function_type,
417     std::vector<std::unique_ptr<ScoreExpression>> args) {
418   if (args.empty()) {
419     return absl_ports::InvalidArgumentError(
420         "List operation functions must have at least one argument.");
421   }
422   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
423 
424   switch (function_type) {
425     case FunctionType::kFilterByRange:
426       if (args.size() != 3) {
427         return absl_ports::InvalidArgumentError(
428             "filterByRange must have 3 arguments.");
429       }
430       if (args[0]->type() != ScoreExpressionType::kDoubleList) {
431         return absl_ports::InvalidArgumentError(
432             "Should expect a list type value for the first argument of "
433             "filterByRange.");
434       }
435       if (args.at(1)->type() != ScoreExpressionType::kDouble ||
436           args.at(2)->type() != ScoreExpressionType::kDouble) {
437         return absl_ports::InvalidArgumentError(
438             "Should expect double type values for the second and third "
439             "arguments of filterByRange.");
440       }
441       break;
442   }
443   return std::unique_ptr<ListOperationFunctionScoreExpression>(
444       new ListOperationFunctionScoreExpression(function_type, std::move(args)));
445 }
446 
447 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const448 ListOperationFunctionScoreExpression::EvaluateList(
449     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
450   switch (function_type_) {
451     case FunctionType::kFilterByRange:
452       ICING_ASSIGN_OR_RETURN(std::vector<double> list_value,
453                              args_.at(0)->EvaluateList(hit_info, query_it));
454       ICING_ASSIGN_OR_RETURN(double low,
455                              args_.at(1)->EvaluateDouble(hit_info, query_it));
456       ICING_ASSIGN_OR_RETURN(double high,
457                              args_.at(2)->EvaluateDouble(hit_info, query_it));
458       if (low > high) {
459         return absl_ports::InvalidArgumentError(
460             "The lower bound cannot be greater than the upper bound.");
461       }
462       auto new_end =
463           std::remove_if(list_value.begin(), list_value.end(),
464                          [low, high](double v) { return v < low || v > high; });
465       list_value.erase(new_end, list_value.end());
466       return list_value;
467       break;
468   }
469   return absl_ports::InternalError("Should never reach here.");
470 }
471 
472 const std::unordered_map<std::string,
473                          DocumentFunctionScoreExpression::FunctionType>
474     DocumentFunctionScoreExpression::kFunctionNames = {
475         {"documentScore", FunctionType::kDocumentScore},
476         {"creationTimestamp", FunctionType::kCreationTimestamp},
477         {"usageCount", FunctionType::kUsageCount},
478         {"usageLastUsedTimestamp", FunctionType::kUsageLastUsedTimestamp}};
479 
480 libtextclassifier3::StatusOr<std::unique_ptr<DocumentFunctionScoreExpression>>
Create(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,double default_score,int64_t current_time_ms)481 DocumentFunctionScoreExpression::Create(
482     FunctionType function_type,
483     std::vector<std::unique_ptr<ScoreExpression>> args,
484     const DocumentStore* document_store, double default_score,
485     int64_t current_time_ms) {
486   if (args.empty()) {
487     return absl_ports::InvalidArgumentError(
488         "Document-based functions must have at least one argument.");
489   }
490   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
491 
492   if (args[0]->type() != ScoreExpressionType::kDocument) {
493     return absl_ports::InvalidArgumentError(
494         "The first parameter of document-based functions must be \"this\".");
495   }
496   switch (function_type) {
497     case FunctionType::kDocumentScore:
498       [[fallthrough]];
499     case FunctionType::kCreationTimestamp:
500       if (args.size() != 1) {
501         return absl_ports::InvalidArgumentError(
502             "DocumentScore/CreationTimestamp must have 1 argument.");
503       }
504       break;
505     case FunctionType::kUsageCount:
506       [[fallthrough]];
507     case FunctionType::kUsageLastUsedTimestamp:
508       if (args.size() != 2 || args[1]->type() != ScoreExpressionType::kDouble) {
509         return absl_ports::InvalidArgumentError(
510             "UsageCount/UsageLastUsedTimestamp must have 2 arguments. The "
511             "first argument should be \"this\", and the second argument "
512             "should be the usage type.");
513       }
514       break;
515   }
516   return std::unique_ptr<DocumentFunctionScoreExpression>(
517       new DocumentFunctionScoreExpression(function_type, std::move(args),
518                                           document_store, default_score,
519                                           current_time_ms));
520 }
521 
522 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const523 DocumentFunctionScoreExpression::EvaluateDouble(
524     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
525   switch (function_type_) {
526     case FunctionType::kDocumentScore:
527       [[fallthrough]];
528     case FunctionType::kCreationTimestamp: {
529       ICING_ASSIGN_OR_RETURN(DocumentAssociatedScoreData score_data,
530                              document_store_.GetDocumentAssociatedScoreData(
531                                  hit_info.document_id()),
532                              default_score_);
533       if (function_type_ == FunctionType::kDocumentScore) {
534         return static_cast<double>(score_data.document_score());
535       }
536       return static_cast<double>(score_data.creation_timestamp_ms());
537     }
538     case FunctionType::kUsageCount:
539       [[fallthrough]];
540     case FunctionType::kUsageLastUsedTimestamp: {
541       ICING_ASSIGN_OR_RETURN(double raw_usage_type,
542                              args_[1]->EvaluateDouble(hit_info, query_it));
543       int usage_type = (int)raw_usage_type;
544       if (usage_type < 1 || usage_type > 3 || raw_usage_type != usage_type) {
545         return absl_ports::InvalidArgumentError(
546             "Usage type must be an integer from 1 to 3");
547       }
548       std::optional<UsageStore::UsageScores> usage_scores =
549           document_store_.GetUsageScores(hit_info.document_id(),
550                                          current_time_ms_);
551       if (!usage_scores) {
552         // If there's no UsageScores entry present for this doc, then just
553         // treat it as a default instance.
554         usage_scores = UsageStore::UsageScores();
555       }
556       if (function_type_ == FunctionType::kUsageCount) {
557         if (usage_type == 1) {
558           return usage_scores->usage_type1_count;
559         } else if (usage_type == 2) {
560           return usage_scores->usage_type2_count;
561         } else {
562           return usage_scores->usage_type3_count;
563         }
564       }
565       if (usage_type == 1) {
566         return usage_scores->usage_type1_last_used_timestamp_s * 1000.0;
567       } else if (usage_type == 2) {
568         return usage_scores->usage_type2_last_used_timestamp_s * 1000.0;
569       } else {
570         return usage_scores->usage_type3_last_used_timestamp_s * 1000.0;
571       }
572     }
573   }
574 }
575 
576 libtextclassifier3::StatusOr<
577     std::unique_ptr<RelevanceScoreFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,Bm25fCalculator * bm25f_calculator,double default_score)578 RelevanceScoreFunctionScoreExpression::Create(
579     std::vector<std::unique_ptr<ScoreExpression>> args,
580     Bm25fCalculator* bm25f_calculator, double default_score) {
581   if (args.size() != 1) {
582     return absl_ports::InvalidArgumentError(
583         "relevanceScore must have 1 argument.");
584   }
585   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
586 
587   if (args[0]->type() != ScoreExpressionType::kDocument) {
588     return absl_ports::InvalidArgumentError(
589         "relevanceScore must take \"this\" as its argument.");
590   }
591   return std::unique_ptr<RelevanceScoreFunctionScoreExpression>(
592       new RelevanceScoreFunctionScoreExpression(bm25f_calculator,
593                                                 default_score));
594 }
595 
596 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const597 RelevanceScoreFunctionScoreExpression::EvaluateDouble(
598     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
599   if (query_it == nullptr) {
600     return default_score_;
601   }
602   return static_cast<double>(
603       bm25f_calculator_.ComputeScore(query_it, hit_info, default_score_));
604 }
605 
606 libtextclassifier3::StatusOr<
607     std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore & document_store,const JoinChildrenFetcher * join_children_fetcher,int64_t current_time_ms)608 ChildrenRankingSignalsFunctionScoreExpression::Create(
609     std::vector<std::unique_ptr<ScoreExpression>> args,
610     const DocumentStore& document_store,
611     const JoinChildrenFetcher* join_children_fetcher, int64_t current_time_ms) {
612   if (args.size() != 1) {
613     return absl_ports::InvalidArgumentError(
614         "childrenRankingSignals must have 1 argument.");
615   }
616   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
617 
618   if (args[0]->type() != ScoreExpressionType::kDocument) {
619     return absl_ports::InvalidArgumentError(
620         "childrenRankingSignals must take \"this\" as its argument.");
621   }
622   if (join_children_fetcher == nullptr) {
623     return absl_ports::InvalidArgumentError(
624         "childrenRankingSignals must only be used with join, but "
625         "JoinChildrenFetcher is not provided.");
626   }
627   return std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>(
628       new ChildrenRankingSignalsFunctionScoreExpression(
629           document_store, *join_children_fetcher, current_time_ms));
630 }
631 
632 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const633 ChildrenRankingSignalsFunctionScoreExpression::EvaluateList(
634     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
635   ICING_ASSIGN_OR_RETURN(
636       std::vector<ScoredDocumentHit> children_hits,
637       join_children_fetcher_.GetChildren(hit_info.document_id()));
638   std::vector<double> children_scores;
639   children_scores.reserve(children_hits.size());
640   for (const ScoredDocumentHit& child_hit : children_hits) {
641     children_scores.push_back(child_hit.score());
642   }
643   return std::move(children_scores);
644 }
645 
646 libtextclassifier3::StatusOr<
647     std::unique_ptr<PropertyWeightsFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,const SectionWeights * section_weights,int64_t current_time_ms)648 PropertyWeightsFunctionScoreExpression::Create(
649     std::vector<std::unique_ptr<ScoreExpression>> args,
650     const DocumentStore* document_store, const SectionWeights* section_weights,
651     int64_t current_time_ms) {
652   if (args.size() != 1) {
653     return absl_ports::InvalidArgumentError(
654         "propertyWeights must have 1 argument.");
655   }
656   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
657 
658   if (args[0]->type() != ScoreExpressionType::kDocument) {
659     return absl_ports::InvalidArgumentError(
660         "propertyWeights must take \"this\" as its argument.");
661   }
662   return std::unique_ptr<PropertyWeightsFunctionScoreExpression>(
663       new PropertyWeightsFunctionScoreExpression(
664           document_store, section_weights, current_time_ms));
665 }
666 
667 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator *) const668 PropertyWeightsFunctionScoreExpression::EvaluateList(
669     const DocHitInfo& hit_info, const DocHitInfoIterator*) const {
670   std::vector<double> weights;
671   SectionIdMask sections = hit_info.hit_section_ids_mask();
672   SchemaTypeId schema_type_id = GetSchemaTypeId(
673       hit_info.document_id(), document_store_, current_time_ms_);
674 
675   while (sections != 0) {
676     SectionId section_id = __builtin_ctzll(sections);
677     sections &= ~(UINT64_C(1) << section_id);
678     weights.push_back(section_weights_.GetNormalizedSectionWeight(
679         schema_type_id, section_id));
680   }
681   return weights;
682 }
683 
684 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args)685 GetEmbeddingParameterFunctionScoreExpression::Create(
686     std::vector<std::unique_ptr<ScoreExpression>> args) {
687   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
688 
689   if (args.size() != 1) {
690     return absl_ports::InvalidArgumentError(
691         absl_ports::StrCat(kFunctionName, " must have 1 argument."));
692   }
693   if (args[0]->type() != ScoreExpressionType::kDouble) {
694     return absl_ports::InvalidArgumentError(
695         absl_ports::StrCat(kFunctionName, " got invalid argument type."));
696   }
697   bool is_constant = args[0]->is_constant();
698   std::unique_ptr<ScoreExpression> expression =
699       std::unique_ptr<GetEmbeddingParameterFunctionScoreExpression>(
700           new GetEmbeddingParameterFunctionScoreExpression(std::move(args[0])));
701   if (is_constant) {
702     ICING_ASSIGN_OR_RETURN(double constant_value,
703                            expression->EvaluateDouble(DocHitInfo(),
704                                                       /*query_it=*/nullptr));
705     return ConstantScoreExpression::Create(constant_value, expression->type());
706   }
707   return expression;
708 }
709 
710 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const711 GetEmbeddingParameterFunctionScoreExpression::EvaluateDouble(
712     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
713   ICING_ASSIGN_OR_RETURN(double raw_query_index,
714                          arg_->EvaluateDouble(hit_info, query_it));
715   if (raw_query_index < 0) {
716     return absl_ports::InvalidArgumentError(
717         "The index of an embedding query must be a non-negative integer.");
718   }
719   if (raw_query_index > std::numeric_limits<uint32_t>::max()) {
720     return absl_ports::InvalidArgumentError(
721         "The index of an embedding query exceeds the maximum value of uint32.");
722   }
723   uint32_t query_index = (uint32_t)raw_query_index;
724   if (query_index != raw_query_index) {
725     return absl_ports::InvalidArgumentError(
726         "The index of an embedding query must be an integer.");
727   }
728   return query_index;
729 }
730 
731 libtextclassifier3::StatusOr<
732     std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,const EmbeddingQueryResults * embedding_query_results)733 MatchedSemanticScoresFunctionScoreExpression::Create(
734     std::vector<std::unique_ptr<ScoreExpression>> args,
735     SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
736     const EmbeddingQueryResults* embedding_query_results) {
737   ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
738   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
739 
740   if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) {
741     return absl_ports::InvalidArgumentError(
742         absl_ports::StrCat(kFunctionName, " is not called with \"this\""));
743   }
744   if (args.size() != 2 && args.size() != 3) {
745     return absl_ports::InvalidArgumentError(
746         absl_ports::StrCat(kFunctionName, " got invalid number of arguments."));
747   }
748   ScoreExpression* embedding_index_arg = args[1].get();
749   if (embedding_index_arg->type() != ScoreExpressionType::kVectorIndex) {
750     return absl_ports::InvalidArgumentError(absl_ports::StrCat(
751         kFunctionName, " got invalid argument type for embedding vector."));
752   }
753   if (args.size() == 3 && args[2]->type() != ScoreExpressionType::kString) {
754     return absl_ports::InvalidArgumentError(
755         "Embedding metric can only be given as a string.");
756   }
757 
758   SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
759       default_metric_type;
760   if (args.size() == 3) {
761     if (!args[2]->is_constant()) {
762       return absl_ports::InvalidArgumentError(
763           "Embedding metric can only be given as a constant string.");
764     }
765     ICING_ASSIGN_OR_RETURN(std::string_view metric, args[2]->EvaluateString());
766     ICING_ASSIGN_OR_RETURN(
767         metric_type,
768         embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
769   }
770   if (embedding_index_arg->is_constant()) {
771     ICING_ASSIGN_OR_RETURN(
772         uint32_t embedding_index,
773         embedding_index_arg->EvaluateDouble(DocHitInfo(),
774                                             /*query_it=*/nullptr));
775     if (embedding_query_results->GetScoreMap(embedding_index, metric_type) ==
776         nullptr) {
777       return absl_ports::InvalidArgumentError(absl_ports::StrCat(
778           "The embedding query index ", std::to_string(embedding_index),
779           " with metric type ",
780           SearchSpecProto::EmbeddingQueryMetricType::Code_Name(metric_type),
781           " has not been queried."));
782     }
783   }
784   return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>(
785       new MatchedSemanticScoresFunctionScoreExpression(
786           std::move(args), metric_type, *embedding_query_results));
787 }
788 
789 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const790 MatchedSemanticScoresFunctionScoreExpression::EvaluateList(
791     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
792   ICING_ASSIGN_OR_RETURN(double raw_query_index,
793                          args_[1]->EvaluateDouble(hit_info, query_it));
794   uint32_t query_index = (uint32_t)raw_query_index;
795   const std::vector<double>* scores =
796       embedding_query_results_.GetMatchedScoresForDocument(
797           query_index, metric_type_, hit_info.document_id());
798   if (scores == nullptr) {
799     return std::vector<double>();
800   }
801   return *scores;
802 }
803 
804 GetScorablePropertyFunctionScoreExpression::
GetScorablePropertyFunctionScoreExpression(const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms,std::unordered_set<SchemaTypeId> && schema_type_ids,std::string_view property_path)805     GetScorablePropertyFunctionScoreExpression(
806         const DocumentStore* document_store, const SchemaStore* schema_store,
807         int64_t current_time_ms,
808         std::unordered_set<SchemaTypeId>&& schema_type_ids,
809         std::string_view property_path)
810     : document_store_(*document_store),
811       schema_store_(*schema_store),
812       current_time_ms_(current_time_ms),
813       schema_type_ids_(std::move(schema_type_ids)),
814       property_path_(property_path) {}
815 
816 libtextclassifier3::StatusOr<std::unordered_set<SchemaTypeId>>
GetAndValidateSchemaTypeIds(std::string_view alias_schema_type,std::string_view property_path,const SchemaTypeAliasMap & schema_type_alias_map,const SchemaStore & schema_store)817 GetScorablePropertyFunctionScoreExpression::GetAndValidateSchemaTypeIds(
818     std::string_view alias_schema_type, std::string_view property_path,
819     const SchemaTypeAliasMap& schema_type_alias_map,
820     const SchemaStore& schema_store) {
821   auto alias_map_iter = schema_type_alias_map.find(alias_schema_type.data());
822   if (alias_map_iter == schema_type_alias_map.end()) {
823     return absl_ports::InvalidArgumentError(absl_ports::StrCat(
824         "The alias schema type in the score expression is not found in the "
825         "schema_type_alias_map: ",
826         alias_schema_type));
827   }
828 
829   std::unordered_set<SchemaTypeId> schema_type_ids;
830   for (std::string_view schema_type : alias_map_iter->second) {
831     // First, verify that the schema type has a valid schema type id in the
832     // schema store.
833     libtextclassifier3::StatusOr<SchemaTypeId> schema_type_id_or =
834         schema_store.GetSchemaTypeId(schema_type);
835     if (!schema_type_id_or.ok()) {
836       if (absl_ports::IsNotFound(schema_type_id_or.status())) {
837         // Ignores the schema type if it is not found in the schema store.
838         continue;
839       }
840       return schema_type_id_or.status();
841     }
842     SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie();
843 
844     // Then, calls GetScorablePropertyIndex() here to validate if the property
845     // path is scorable under the schema type, no need to check the returned
846     // index value.
847     libtextclassifier3::StatusOr<std::optional<int>>
848         scorable_property_index_or = schema_store.GetScorablePropertyIndex(
849             schema_type_id, property_path);
850     if (!scorable_property_index_or.ok()) {
851       return scorable_property_index_or.status();
852     }
853     if (!scorable_property_index_or.ValueOrDie().has_value()) {
854       return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
855           "'%s' is not defined as a scorable property under schema type %d",
856           property_path.data(), schema_type_id));
857     }
858     schema_type_ids.insert(schema_type_id);
859   }
860   return schema_type_ids;
861 }
862 
863 libtextclassifier3::StatusOr<
864     std::unique_ptr<GetScorablePropertyFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,const SchemaStore * schema_store,const SchemaTypeAliasMap & schema_type_alias_map,int64_t current_time_ms)865 GetScorablePropertyFunctionScoreExpression::Create(
866     std::vector<std::unique_ptr<ScoreExpression>> args,
867     const DocumentStore* document_store, const SchemaStore* schema_store,
868     const SchemaTypeAliasMap& schema_type_alias_map, int64_t current_time_ms) {
869   ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
870 
871   if (args.size() != 2 || args[0]->type() != ScoreExpressionType::kString ||
872       args[1]->type() != ScoreExpressionType::kString) {
873     return absl_ports::InvalidArgumentError(absl_ports::StrCat(
874         kFunctionName, " must take exactly two string params"));
875   }
876 
877   // Validate schema type.
878   ICING_ASSIGN_OR_RETURN(std::string_view alias_schema_type,
879                          args[0]->EvaluateString());
880   ICING_ASSIGN_OR_RETURN(std::string_view property_path,
881                          args[1]->EvaluateString());
882   ICING_ASSIGN_OR_RETURN(
883       std::unordered_set<SchemaTypeId> schema_type_ids,
884       GetAndValidateSchemaTypeIds(alias_schema_type, property_path,
885                                   schema_type_alias_map, *schema_store));
886 
887   return std::unique_ptr<GetScorablePropertyFunctionScoreExpression>(
888       new GetScorablePropertyFunctionScoreExpression(
889           document_store, schema_store, current_time_ms,
890           std::move(schema_type_ids), property_path));
891 }
892 
893 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const894 GetScorablePropertyFunctionScoreExpression::EvaluateList(
895     const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
896   SchemaTypeId doc_schema_type_id = GetSchemaTypeId(
897       hit_info.document_id(), document_store_, current_time_ms_);
898   if (schema_type_ids_.find(doc_schema_type_id) == schema_type_ids_.end()) {
899     return std::vector<double>();
900   }
901 
902   std::unique_ptr<ScorablePropertySet> scorable_property_set =
903       document_store_.GetScorablePropertySet(hit_info.document_id(),
904                                              current_time_ms_);
905   // It should never happen.
906   if (scorable_property_set == nullptr) {
907     return absl_ports::InternalError(IcingStringUtil::StringPrintf(
908         "Failed to retrieve ScorablePropertySet for document %d",
909         hit_info.document_id()));
910   }
911 
912   const ScorablePropertyProto* scorable_property_proto =
913       scorable_property_set->GetScorablePropertyProto(property_path_);
914   // It should never happen as icing generates a default value for each scorable
915   // property when the document is created.
916   if (scorable_property_proto == nullptr) {
917     return absl_ports::InternalError(IcingStringUtil::StringPrintf(
918         "Failed to retrieve ScorablePropertyProto for document %d, and "
919         "property path %s",
920         hit_info.document_id(), property_path_.c_str()));
921   }
922 
923   // Converts ScorablePropertyProto to a vector of doubles.
924   if (scorable_property_proto->int64_values_size() > 0) {
925     return std::vector<double>(scorable_property_proto->int64_values().begin(),
926                                scorable_property_proto->int64_values().end());
927   } else if (scorable_property_proto->double_values_size() > 0) {
928     return std::vector<double>(scorable_property_proto->double_values().begin(),
929                                scorable_property_proto->double_values().end());
930   } else if (scorable_property_proto->boolean_values_size() > 0) {
931     return std::vector<double>(
932         scorable_property_proto->boolean_values().begin(),
933         scorable_property_proto->boolean_values().end());
934   }
935   return std::vector<double>();
936 }
937 
938 }  // namespace lib
939 }  // namespace icing
940