1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/scoring/advanced_scoring/score-expression.h"
16
17 #include <algorithm>
18 #include <cmath>
19 #include <cstdint>
20 #include <cstdlib>
21 #include <limits>
22 #include <memory>
23 #include <numeric>
24 #include <optional>
25 #include <string>
26 #include <string_view>
27 #include <unordered_map>
28 #include <unordered_set>
29 #include <utility>
30 #include <vector>
31
32 #include "icing/text_classifier/lib3/utils/base/status.h"
33 #include "icing/text_classifier/lib3/utils/base/statusor.h"
34 #include "icing/absl_ports/canonical_errors.h"
35 #include "icing/absl_ports/str_cat.h"
36 #include "icing/index/embed/embedding-query-results.h"
37 #include "icing/index/hit/doc-hit-info.h"
38 #include "icing/index/iterator/doc-hit-info-iterator.h"
39 #include "icing/join/join-children-fetcher.h"
40 #include "icing/legacy/core/icing-string-util.h"
41 #include "icing/proto/internal/scorable_property_set.pb.h"
42 #include "icing/schema/schema-store.h"
43 #include "icing/schema/section.h"
44 #include "icing/scoring/bm25f-calculator.h"
45 #include "icing/scoring/scored-document-hit.h"
46 #include "icing/scoring/section-weights.h"
47 #include "icing/store/document-associated-score-data.h"
48 #include "icing/store/document-filter-data.h"
49 #include "icing/store/document-id.h"
50 #include "icing/store/document-store.h"
51 #include "icing/util/embedding-util.h"
52 #include "icing/util/logging.h"
53 #include "icing/util/scorable_property_set.h"
54 #include "icing/util/status-macros.h"
55
56 namespace icing {
57 namespace lib {
58
59 namespace {
60
CheckChildrenNotNull(const std::vector<std::unique_ptr<ScoreExpression>> & children)61 libtextclassifier3::Status CheckChildrenNotNull(
62 const std::vector<std::unique_ptr<ScoreExpression>>& children) {
63 for (const auto& child : children) {
64 ICING_RETURN_ERROR_IF_NULL(child);
65 }
66 return libtextclassifier3::Status::OK;
67 }
68
GetSchemaTypeId(DocumentId document_id,const DocumentStore & document_store,int64_t current_time_ms)69 SchemaTypeId GetSchemaTypeId(DocumentId document_id,
70 const DocumentStore& document_store,
71 int64_t current_time_ms) {
72 auto filter_data_optional =
73 document_store.GetAliveDocumentFilterData(document_id, current_time_ms);
74 if (!filter_data_optional) {
75 // This should never happen. The only failure case for
76 // GetAliveDocumentFilterData is if the document_id is outside of the range
77 // of allocated document_ids, which shouldn't be possible since we're
78 // getting this document_id from the posting lists.
79 ICING_LOG(WARNING) << "No document filter data for document ["
80 << document_id << "]";
81 return kInvalidSchemaTypeId;
82 }
83 return filter_data_optional.value().schema_type_id();
84 }
85
86 } // namespace
87
88 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(OperatorType op,std::vector<std::unique_ptr<ScoreExpression>> children)89 OperatorScoreExpression::Create(
90 OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) {
91 if (children.empty()) {
92 return absl_ports::InvalidArgumentError(
93 "OperatorScoreExpression must have at least one argument.");
94 }
95 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(children));
96
97 bool children_all_constant_double = true;
98 for (const auto& child : children) {
99 if (child->type() != ScoreExpressionType::kDouble) {
100 return absl_ports::InvalidArgumentError(
101 "Operators are only supported for double type.");
102 }
103 if (!child->is_constant()) {
104 children_all_constant_double = false;
105 }
106 }
107 if (op == OperatorType::kNegative) {
108 if (children.size() != 1) {
109 return absl_ports::InvalidArgumentError(
110 "Negative operator must have only 1 argument.");
111 }
112 }
113 std::unique_ptr<ScoreExpression> expression =
114 std::unique_ptr<OperatorScoreExpression>(
115 new OperatorScoreExpression(op, std::move(children)));
116 if (children_all_constant_double) {
117 // Because all of the children are constants, this expression does not
118 // depend on the DocHitInto or query_it that are passed into it.
119 ICING_ASSIGN_OR_RETURN(double constant_value,
120 expression->EvaluateDouble(DocHitInfo(),
121 /*query_it=*/nullptr));
122 return ConstantScoreExpression::Create(constant_value);
123 }
124 return expression;
125 }
126
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const127 libtextclassifier3::StatusOr<double> OperatorScoreExpression::EvaluateDouble(
128 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
129 // The Create factory guarantees that an operator will have at least one
130 // child.
131 ICING_ASSIGN_OR_RETURN(double res,
132 children_.at(0)->EvaluateDouble(hit_info, query_it));
133
134 if (op_ == OperatorType::kNegative) {
135 return -res;
136 }
137
138 for (int i = 1; i < children_.size(); ++i) {
139 ICING_ASSIGN_OR_RETURN(double v,
140 children_.at(i)->EvaluateDouble(hit_info, query_it));
141 switch (op_) {
142 case OperatorType::kPlus:
143 res += v;
144 break;
145 case OperatorType::kMinus:
146 res -= v;
147 break;
148 case OperatorType::kTimes:
149 res *= v;
150 break;
151 case OperatorType::kDiv:
152 res /= v;
153 break;
154 case OperatorType::kNegative:
155 return absl_ports::InternalError("Should never reach here.");
156 }
157 if (!std::isfinite(res)) {
158 return absl_ports::InvalidArgumentError(
159 "Got a non-finite value while evaluating operator score expression.");
160 }
161 }
162 return res;
163 }
164
165 const std::unordered_map<std::string, MathFunctionScoreExpression::FunctionType>
166 MathFunctionScoreExpression::kFunctionNames = {
167 {"log", FunctionType::kLog},
168 {"pow", FunctionType::kPow},
169 {"max", FunctionType::kMax},
170 {"min", FunctionType::kMin},
171 {"len", FunctionType::kLen},
172 {"sum", FunctionType::kSum},
173 {"avg", FunctionType::kAvg},
174 {"sqrt", FunctionType::kSqrt},
175 {"abs", FunctionType::kAbs},
176 {"sin", FunctionType::kSin},
177 {"cos", FunctionType::kCos},
178 {"tan", FunctionType::kTan},
179 {"maxOrDefault", FunctionType::kMaxOrDefault},
180 {"minOrDefault", FunctionType::kMinOrDefault}};
181
182 const std::unordered_set<MathFunctionScoreExpression::FunctionType>
183 MathFunctionScoreExpression::kVariableArgumentsFunctions = {
184 FunctionType::kMax, FunctionType::kMin, FunctionType::kLen,
185 FunctionType::kSum, FunctionType::kAvg};
186
187 const std::unordered_set<MathFunctionScoreExpression::FunctionType>
188 MathFunctionScoreExpression::kListArgumentFunctions = {
189 FunctionType::kMaxOrDefault, FunctionType::kMinOrDefault};
190
191 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)192 MathFunctionScoreExpression::Create(
193 FunctionType function_type,
194 std::vector<std::unique_ptr<ScoreExpression>> args) {
195 if (args.empty()) {
196 return absl_ports::InvalidArgumentError(
197 "Math functions must have at least one argument.");
198 }
199 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
200
201 // Return early for functions that support variable length arguments and the
202 // first argument is a list.
203 if (args.size() == 1 && args[0]->type() == ScoreExpressionType::kDoubleList &&
204 kVariableArgumentsFunctions.count(function_type) > 0) {
205 return std::unique_ptr<MathFunctionScoreExpression>(
206 new MathFunctionScoreExpression(function_type, std::move(args)));
207 }
208
209 bool args_all_constant_double = false;
210 if (kListArgumentFunctions.count(function_type) > 0) {
211 if (args[0]->type() != ScoreExpressionType::kDoubleList) {
212 return absl_ports::InvalidArgumentError(
213 "Got an invalid type for the math function. Should expect a list "
214 "type value in the first argument.");
215 }
216 } else {
217 args_all_constant_double = true;
218 for (const auto& child : args) {
219 if (child->type() != ScoreExpressionType::kDouble) {
220 return absl_ports::InvalidArgumentError(
221 "Got an invalid type for the math function. Should expect a double "
222 "type argument.");
223 }
224 if (!child->is_constant()) {
225 args_all_constant_double = false;
226 }
227 }
228 }
229 switch (function_type) {
230 case FunctionType::kLog:
231 if (args.size() != 1 && args.size() != 2) {
232 return absl_ports::InvalidArgumentError(
233 "log must have 1 or 2 arguments.");
234 }
235 break;
236 case FunctionType::kPow:
237 if (args.size() != 2) {
238 return absl_ports::InvalidArgumentError("pow must have 2 arguments.");
239 }
240 break;
241 case FunctionType::kSqrt:
242 if (args.size() != 1) {
243 return absl_ports::InvalidArgumentError("sqrt must have 1 argument.");
244 }
245 break;
246 case FunctionType::kAbs:
247 if (args.size() != 1) {
248 return absl_ports::InvalidArgumentError("abs must have 1 argument.");
249 }
250 break;
251 case FunctionType::kSin:
252 if (args.size() != 1) {
253 return absl_ports::InvalidArgumentError("sin must have 1 argument.");
254 }
255 break;
256 case FunctionType::kCos:
257 if (args.size() != 1) {
258 return absl_ports::InvalidArgumentError("cos must have 1 argument.");
259 }
260 break;
261 case FunctionType::kTan:
262 if (args.size() != 1) {
263 return absl_ports::InvalidArgumentError("tan must have 1 argument.");
264 }
265 break;
266 case FunctionType::kMaxOrDefault:
267 if (args.size() != 2) {
268 return absl_ports::InvalidArgumentError(
269 "maxOrDefault must have 2 arguments.");
270 }
271 if (args[1]->type() != ScoreExpressionType::kDouble) {
272 return absl_ports::InvalidArgumentError(
273 "maxOrDefault must have a double type as the second argument.");
274 }
275 break;
276 case FunctionType::kMinOrDefault:
277 if (args.size() != 2) {
278 return absl_ports::InvalidArgumentError(
279 "minOrDefault must have 2 arguments.");
280 }
281 if (args[1]->type() != ScoreExpressionType::kDouble) {
282 return absl_ports::InvalidArgumentError(
283 "minOrDefault must have a double type as the second argument.");
284 }
285 break;
286 // Functions that support variable length arguments
287 case FunctionType::kMax:
288 [[fallthrough]];
289 case FunctionType::kMin:
290 [[fallthrough]];
291 case FunctionType::kLen:
292 [[fallthrough]];
293 case FunctionType::kSum:
294 [[fallthrough]];
295 case FunctionType::kAvg:
296 break;
297 }
298 std::unique_ptr<ScoreExpression> expression =
299 std::unique_ptr<MathFunctionScoreExpression>(
300 new MathFunctionScoreExpression(function_type, std::move(args)));
301 if (args_all_constant_double) {
302 // Because all of the arguments are constants, this expression does not
303 // depend on the DocHitInto or query_it that are passed into it.
304 ICING_ASSIGN_OR_RETURN(double constant_value,
305 expression->EvaluateDouble(DocHitInfo(),
306 /*query_it=*/nullptr));
307 return ConstantScoreExpression::Create(constant_value);
308 }
309 return expression;
310 }
311
312 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const313 MathFunctionScoreExpression::EvaluateDouble(
314 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
315 std::vector<double> values;
316 int ind = 0;
317 if (args_.at(0)->type() == ScoreExpressionType::kDoubleList) {
318 ICING_ASSIGN_OR_RETURN(values,
319 args_.at(0)->EvaluateList(hit_info, query_it));
320 ind = 1;
321 }
322 for (; ind < args_.size(); ++ind) {
323 ICING_ASSIGN_OR_RETURN(double v,
324 args_.at(ind)->EvaluateDouble(hit_info, query_it));
325 values.push_back(v);
326 }
327
328 double res = 0;
329 switch (function_type_) {
330 case FunctionType::kLog:
331 if (values.size() == 1) {
332 res = log(values[0]);
333 } else {
334 // argument 0 is log base
335 // argument 1 is the value
336 res = log(values[1]) / log(values[0]);
337 }
338 break;
339 case FunctionType::kPow:
340 res = pow(values[0], values[1]);
341 break;
342 case FunctionType::kMax:
343 if (values.empty()) {
344 return absl_ports::InvalidArgumentError(
345 "Got an empty parameter set in max function");
346 }
347 res = *std::max_element(values.begin(), values.end());
348 break;
349 case FunctionType::kMin:
350 if (values.empty()) {
351 return absl_ports::InvalidArgumentError(
352 "Got an empty parameter set in min function");
353 }
354 res = *std::min_element(values.begin(), values.end());
355 break;
356 case FunctionType::kLen:
357 res = values.size();
358 break;
359 case FunctionType::kSum:
360 res = std::reduce(values.begin(), values.end());
361 break;
362 case FunctionType::kAvg:
363 if (values.empty()) {
364 return absl_ports::InvalidArgumentError(
365 "Got an empty parameter set in avg function.");
366 }
367 res = std::reduce(values.begin(), values.end()) / values.size();
368 break;
369 case FunctionType::kSqrt:
370 res = sqrt(values[0]);
371 break;
372 case FunctionType::kAbs:
373 res = abs(values[0]);
374 break;
375 case FunctionType::kSin:
376 res = sin(values[0]);
377 break;
378 case FunctionType::kCos:
379 res = cos(values[0]);
380 break;
381 case FunctionType::kTan:
382 res = tan(values[0]);
383 break;
384 // For the following two functions, the last value is the default value.
385 // If values.size() == 1, then it means the provided list is empty.
386 case FunctionType::kMaxOrDefault:
387 if (values.size() == 1) {
388 res = values[0];
389 } else {
390 res = *std::max_element(values.begin(), values.end() - 1);
391 }
392 break;
393 case FunctionType::kMinOrDefault:
394 if (values.size() == 1) {
395 res = values[0];
396 } else {
397 res = *std::min_element(values.begin(), values.end() - 1);
398 }
399 break;
400 }
401 if (!std::isfinite(res)) {
402 return absl_ports::InvalidArgumentError(
403 "Got a non-finite value while evaluating math function score "
404 "expression.");
405 }
406 return res;
407 }
408
409 const std::unordered_map<std::string,
410 ListOperationFunctionScoreExpression::FunctionType>
411 ListOperationFunctionScoreExpression::kFunctionNames = {
412 {"filterByRange", FunctionType::kFilterByRange}};
413
414 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)415 ListOperationFunctionScoreExpression::Create(
416 FunctionType function_type,
417 std::vector<std::unique_ptr<ScoreExpression>> args) {
418 if (args.empty()) {
419 return absl_ports::InvalidArgumentError(
420 "List operation functions must have at least one argument.");
421 }
422 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
423
424 switch (function_type) {
425 case FunctionType::kFilterByRange:
426 if (args.size() != 3) {
427 return absl_ports::InvalidArgumentError(
428 "filterByRange must have 3 arguments.");
429 }
430 if (args[0]->type() != ScoreExpressionType::kDoubleList) {
431 return absl_ports::InvalidArgumentError(
432 "Should expect a list type value for the first argument of "
433 "filterByRange.");
434 }
435 if (args.at(1)->type() != ScoreExpressionType::kDouble ||
436 args.at(2)->type() != ScoreExpressionType::kDouble) {
437 return absl_ports::InvalidArgumentError(
438 "Should expect double type values for the second and third "
439 "arguments of filterByRange.");
440 }
441 break;
442 }
443 return std::unique_ptr<ListOperationFunctionScoreExpression>(
444 new ListOperationFunctionScoreExpression(function_type, std::move(args)));
445 }
446
447 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const448 ListOperationFunctionScoreExpression::EvaluateList(
449 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
450 switch (function_type_) {
451 case FunctionType::kFilterByRange:
452 ICING_ASSIGN_OR_RETURN(std::vector<double> list_value,
453 args_.at(0)->EvaluateList(hit_info, query_it));
454 ICING_ASSIGN_OR_RETURN(double low,
455 args_.at(1)->EvaluateDouble(hit_info, query_it));
456 ICING_ASSIGN_OR_RETURN(double high,
457 args_.at(2)->EvaluateDouble(hit_info, query_it));
458 if (low > high) {
459 return absl_ports::InvalidArgumentError(
460 "The lower bound cannot be greater than the upper bound.");
461 }
462 auto new_end =
463 std::remove_if(list_value.begin(), list_value.end(),
464 [low, high](double v) { return v < low || v > high; });
465 list_value.erase(new_end, list_value.end());
466 return list_value;
467 break;
468 }
469 return absl_ports::InternalError("Should never reach here.");
470 }
471
472 const std::unordered_map<std::string,
473 DocumentFunctionScoreExpression::FunctionType>
474 DocumentFunctionScoreExpression::kFunctionNames = {
475 {"documentScore", FunctionType::kDocumentScore},
476 {"creationTimestamp", FunctionType::kCreationTimestamp},
477 {"usageCount", FunctionType::kUsageCount},
478 {"usageLastUsedTimestamp", FunctionType::kUsageLastUsedTimestamp}};
479
480 libtextclassifier3::StatusOr<std::unique_ptr<DocumentFunctionScoreExpression>>
Create(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,double default_score,int64_t current_time_ms)481 DocumentFunctionScoreExpression::Create(
482 FunctionType function_type,
483 std::vector<std::unique_ptr<ScoreExpression>> args,
484 const DocumentStore* document_store, double default_score,
485 int64_t current_time_ms) {
486 if (args.empty()) {
487 return absl_ports::InvalidArgumentError(
488 "Document-based functions must have at least one argument.");
489 }
490 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
491
492 if (args[0]->type() != ScoreExpressionType::kDocument) {
493 return absl_ports::InvalidArgumentError(
494 "The first parameter of document-based functions must be \"this\".");
495 }
496 switch (function_type) {
497 case FunctionType::kDocumentScore:
498 [[fallthrough]];
499 case FunctionType::kCreationTimestamp:
500 if (args.size() != 1) {
501 return absl_ports::InvalidArgumentError(
502 "DocumentScore/CreationTimestamp must have 1 argument.");
503 }
504 break;
505 case FunctionType::kUsageCount:
506 [[fallthrough]];
507 case FunctionType::kUsageLastUsedTimestamp:
508 if (args.size() != 2 || args[1]->type() != ScoreExpressionType::kDouble) {
509 return absl_ports::InvalidArgumentError(
510 "UsageCount/UsageLastUsedTimestamp must have 2 arguments. The "
511 "first argument should be \"this\", and the second argument "
512 "should be the usage type.");
513 }
514 break;
515 }
516 return std::unique_ptr<DocumentFunctionScoreExpression>(
517 new DocumentFunctionScoreExpression(function_type, std::move(args),
518 document_store, default_score,
519 current_time_ms));
520 }
521
522 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const523 DocumentFunctionScoreExpression::EvaluateDouble(
524 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
525 switch (function_type_) {
526 case FunctionType::kDocumentScore:
527 [[fallthrough]];
528 case FunctionType::kCreationTimestamp: {
529 ICING_ASSIGN_OR_RETURN(DocumentAssociatedScoreData score_data,
530 document_store_.GetDocumentAssociatedScoreData(
531 hit_info.document_id()),
532 default_score_);
533 if (function_type_ == FunctionType::kDocumentScore) {
534 return static_cast<double>(score_data.document_score());
535 }
536 return static_cast<double>(score_data.creation_timestamp_ms());
537 }
538 case FunctionType::kUsageCount:
539 [[fallthrough]];
540 case FunctionType::kUsageLastUsedTimestamp: {
541 ICING_ASSIGN_OR_RETURN(double raw_usage_type,
542 args_[1]->EvaluateDouble(hit_info, query_it));
543 int usage_type = (int)raw_usage_type;
544 if (usage_type < 1 || usage_type > 3 || raw_usage_type != usage_type) {
545 return absl_ports::InvalidArgumentError(
546 "Usage type must be an integer from 1 to 3");
547 }
548 std::optional<UsageStore::UsageScores> usage_scores =
549 document_store_.GetUsageScores(hit_info.document_id(),
550 current_time_ms_);
551 if (!usage_scores) {
552 // If there's no UsageScores entry present for this doc, then just
553 // treat it as a default instance.
554 usage_scores = UsageStore::UsageScores();
555 }
556 if (function_type_ == FunctionType::kUsageCount) {
557 if (usage_type == 1) {
558 return usage_scores->usage_type1_count;
559 } else if (usage_type == 2) {
560 return usage_scores->usage_type2_count;
561 } else {
562 return usage_scores->usage_type3_count;
563 }
564 }
565 if (usage_type == 1) {
566 return usage_scores->usage_type1_last_used_timestamp_s * 1000.0;
567 } else if (usage_type == 2) {
568 return usage_scores->usage_type2_last_used_timestamp_s * 1000.0;
569 } else {
570 return usage_scores->usage_type3_last_used_timestamp_s * 1000.0;
571 }
572 }
573 }
574 }
575
576 libtextclassifier3::StatusOr<
577 std::unique_ptr<RelevanceScoreFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,Bm25fCalculator * bm25f_calculator,double default_score)578 RelevanceScoreFunctionScoreExpression::Create(
579 std::vector<std::unique_ptr<ScoreExpression>> args,
580 Bm25fCalculator* bm25f_calculator, double default_score) {
581 if (args.size() != 1) {
582 return absl_ports::InvalidArgumentError(
583 "relevanceScore must have 1 argument.");
584 }
585 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
586
587 if (args[0]->type() != ScoreExpressionType::kDocument) {
588 return absl_ports::InvalidArgumentError(
589 "relevanceScore must take \"this\" as its argument.");
590 }
591 return std::unique_ptr<RelevanceScoreFunctionScoreExpression>(
592 new RelevanceScoreFunctionScoreExpression(bm25f_calculator,
593 default_score));
594 }
595
596 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const597 RelevanceScoreFunctionScoreExpression::EvaluateDouble(
598 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
599 if (query_it == nullptr) {
600 return default_score_;
601 }
602 return static_cast<double>(
603 bm25f_calculator_.ComputeScore(query_it, hit_info, default_score_));
604 }
605
606 libtextclassifier3::StatusOr<
607 std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore & document_store,const JoinChildrenFetcher * join_children_fetcher,int64_t current_time_ms)608 ChildrenRankingSignalsFunctionScoreExpression::Create(
609 std::vector<std::unique_ptr<ScoreExpression>> args,
610 const DocumentStore& document_store,
611 const JoinChildrenFetcher* join_children_fetcher, int64_t current_time_ms) {
612 if (args.size() != 1) {
613 return absl_ports::InvalidArgumentError(
614 "childrenRankingSignals must have 1 argument.");
615 }
616 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
617
618 if (args[0]->type() != ScoreExpressionType::kDocument) {
619 return absl_ports::InvalidArgumentError(
620 "childrenRankingSignals must take \"this\" as its argument.");
621 }
622 if (join_children_fetcher == nullptr) {
623 return absl_ports::InvalidArgumentError(
624 "childrenRankingSignals must only be used with join, but "
625 "JoinChildrenFetcher is not provided.");
626 }
627 return std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>(
628 new ChildrenRankingSignalsFunctionScoreExpression(
629 document_store, *join_children_fetcher, current_time_ms));
630 }
631
632 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const633 ChildrenRankingSignalsFunctionScoreExpression::EvaluateList(
634 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
635 ICING_ASSIGN_OR_RETURN(
636 std::vector<ScoredDocumentHit> children_hits,
637 join_children_fetcher_.GetChildren(hit_info.document_id()));
638 std::vector<double> children_scores;
639 children_scores.reserve(children_hits.size());
640 for (const ScoredDocumentHit& child_hit : children_hits) {
641 children_scores.push_back(child_hit.score());
642 }
643 return std::move(children_scores);
644 }
645
646 libtextclassifier3::StatusOr<
647 std::unique_ptr<PropertyWeightsFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,const SectionWeights * section_weights,int64_t current_time_ms)648 PropertyWeightsFunctionScoreExpression::Create(
649 std::vector<std::unique_ptr<ScoreExpression>> args,
650 const DocumentStore* document_store, const SectionWeights* section_weights,
651 int64_t current_time_ms) {
652 if (args.size() != 1) {
653 return absl_ports::InvalidArgumentError(
654 "propertyWeights must have 1 argument.");
655 }
656 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
657
658 if (args[0]->type() != ScoreExpressionType::kDocument) {
659 return absl_ports::InvalidArgumentError(
660 "propertyWeights must take \"this\" as its argument.");
661 }
662 return std::unique_ptr<PropertyWeightsFunctionScoreExpression>(
663 new PropertyWeightsFunctionScoreExpression(
664 document_store, section_weights, current_time_ms));
665 }
666
667 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator *) const668 PropertyWeightsFunctionScoreExpression::EvaluateList(
669 const DocHitInfo& hit_info, const DocHitInfoIterator*) const {
670 std::vector<double> weights;
671 SectionIdMask sections = hit_info.hit_section_ids_mask();
672 SchemaTypeId schema_type_id = GetSchemaTypeId(
673 hit_info.document_id(), document_store_, current_time_ms_);
674
675 while (sections != 0) {
676 SectionId section_id = __builtin_ctzll(sections);
677 sections &= ~(UINT64_C(1) << section_id);
678 weights.push_back(section_weights_.GetNormalizedSectionWeight(
679 schema_type_id, section_id));
680 }
681 return weights;
682 }
683
684 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args)685 GetEmbeddingParameterFunctionScoreExpression::Create(
686 std::vector<std::unique_ptr<ScoreExpression>> args) {
687 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
688
689 if (args.size() != 1) {
690 return absl_ports::InvalidArgumentError(
691 absl_ports::StrCat(kFunctionName, " must have 1 argument."));
692 }
693 if (args[0]->type() != ScoreExpressionType::kDouble) {
694 return absl_ports::InvalidArgumentError(
695 absl_ports::StrCat(kFunctionName, " got invalid argument type."));
696 }
697 bool is_constant = args[0]->is_constant();
698 std::unique_ptr<ScoreExpression> expression =
699 std::unique_ptr<GetEmbeddingParameterFunctionScoreExpression>(
700 new GetEmbeddingParameterFunctionScoreExpression(std::move(args[0])));
701 if (is_constant) {
702 ICING_ASSIGN_OR_RETURN(double constant_value,
703 expression->EvaluateDouble(DocHitInfo(),
704 /*query_it=*/nullptr));
705 return ConstantScoreExpression::Create(constant_value, expression->type());
706 }
707 return expression;
708 }
709
710 libtextclassifier3::StatusOr<double>
EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const711 GetEmbeddingParameterFunctionScoreExpression::EvaluateDouble(
712 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
713 ICING_ASSIGN_OR_RETURN(double raw_query_index,
714 arg_->EvaluateDouble(hit_info, query_it));
715 if (raw_query_index < 0) {
716 return absl_ports::InvalidArgumentError(
717 "The index of an embedding query must be a non-negative integer.");
718 }
719 if (raw_query_index > std::numeric_limits<uint32_t>::max()) {
720 return absl_ports::InvalidArgumentError(
721 "The index of an embedding query exceeds the maximum value of uint32.");
722 }
723 uint32_t query_index = (uint32_t)raw_query_index;
724 if (query_index != raw_query_index) {
725 return absl_ports::InvalidArgumentError(
726 "The index of an embedding query must be an integer.");
727 }
728 return query_index;
729 }
730
731 libtextclassifier3::StatusOr<
732 std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,const EmbeddingQueryResults * embedding_query_results)733 MatchedSemanticScoresFunctionScoreExpression::Create(
734 std::vector<std::unique_ptr<ScoreExpression>> args,
735 SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type,
736 const EmbeddingQueryResults* embedding_query_results) {
737 ICING_RETURN_ERROR_IF_NULL(embedding_query_results);
738 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
739
740 if (args.empty() || args[0]->type() != ScoreExpressionType::kDocument) {
741 return absl_ports::InvalidArgumentError(
742 absl_ports::StrCat(kFunctionName, " is not called with \"this\""));
743 }
744 if (args.size() != 2 && args.size() != 3) {
745 return absl_ports::InvalidArgumentError(
746 absl_ports::StrCat(kFunctionName, " got invalid number of arguments."));
747 }
748 ScoreExpression* embedding_index_arg = args[1].get();
749 if (embedding_index_arg->type() != ScoreExpressionType::kVectorIndex) {
750 return absl_ports::InvalidArgumentError(absl_ports::StrCat(
751 kFunctionName, " got invalid argument type for embedding vector."));
752 }
753 if (args.size() == 3 && args[2]->type() != ScoreExpressionType::kString) {
754 return absl_ports::InvalidArgumentError(
755 "Embedding metric can only be given as a string.");
756 }
757
758 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type =
759 default_metric_type;
760 if (args.size() == 3) {
761 if (!args[2]->is_constant()) {
762 return absl_ports::InvalidArgumentError(
763 "Embedding metric can only be given as a constant string.");
764 }
765 ICING_ASSIGN_OR_RETURN(std::string_view metric, args[2]->EvaluateString());
766 ICING_ASSIGN_OR_RETURN(
767 metric_type,
768 embedding_util::GetEmbeddingQueryMetricTypeFromName(metric));
769 }
770 if (embedding_index_arg->is_constant()) {
771 ICING_ASSIGN_OR_RETURN(
772 uint32_t embedding_index,
773 embedding_index_arg->EvaluateDouble(DocHitInfo(),
774 /*query_it=*/nullptr));
775 if (embedding_query_results->GetScoreMap(embedding_index, metric_type) ==
776 nullptr) {
777 return absl_ports::InvalidArgumentError(absl_ports::StrCat(
778 "The embedding query index ", std::to_string(embedding_index),
779 " with metric type ",
780 SearchSpecProto::EmbeddingQueryMetricType::Code_Name(metric_type),
781 " has not been queried."));
782 }
783 }
784 return std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>(
785 new MatchedSemanticScoresFunctionScoreExpression(
786 std::move(args), metric_type, *embedding_query_results));
787 }
788
789 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const790 MatchedSemanticScoresFunctionScoreExpression::EvaluateList(
791 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
792 ICING_ASSIGN_OR_RETURN(double raw_query_index,
793 args_[1]->EvaluateDouble(hit_info, query_it));
794 uint32_t query_index = (uint32_t)raw_query_index;
795 const std::vector<double>* scores =
796 embedding_query_results_.GetMatchedScoresForDocument(
797 query_index, metric_type_, hit_info.document_id());
798 if (scores == nullptr) {
799 return std::vector<double>();
800 }
801 return *scores;
802 }
803
804 GetScorablePropertyFunctionScoreExpression::
GetScorablePropertyFunctionScoreExpression(const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms,std::unordered_set<SchemaTypeId> && schema_type_ids,std::string_view property_path)805 GetScorablePropertyFunctionScoreExpression(
806 const DocumentStore* document_store, const SchemaStore* schema_store,
807 int64_t current_time_ms,
808 std::unordered_set<SchemaTypeId>&& schema_type_ids,
809 std::string_view property_path)
810 : document_store_(*document_store),
811 schema_store_(*schema_store),
812 current_time_ms_(current_time_ms),
813 schema_type_ids_(std::move(schema_type_ids)),
814 property_path_(property_path) {}
815
816 libtextclassifier3::StatusOr<std::unordered_set<SchemaTypeId>>
GetAndValidateSchemaTypeIds(std::string_view alias_schema_type,std::string_view property_path,const SchemaTypeAliasMap & schema_type_alias_map,const SchemaStore & schema_store)817 GetScorablePropertyFunctionScoreExpression::GetAndValidateSchemaTypeIds(
818 std::string_view alias_schema_type, std::string_view property_path,
819 const SchemaTypeAliasMap& schema_type_alias_map,
820 const SchemaStore& schema_store) {
821 auto alias_map_iter = schema_type_alias_map.find(alias_schema_type.data());
822 if (alias_map_iter == schema_type_alias_map.end()) {
823 return absl_ports::InvalidArgumentError(absl_ports::StrCat(
824 "The alias schema type in the score expression is not found in the "
825 "schema_type_alias_map: ",
826 alias_schema_type));
827 }
828
829 std::unordered_set<SchemaTypeId> schema_type_ids;
830 for (std::string_view schema_type : alias_map_iter->second) {
831 // First, verify that the schema type has a valid schema type id in the
832 // schema store.
833 libtextclassifier3::StatusOr<SchemaTypeId> schema_type_id_or =
834 schema_store.GetSchemaTypeId(schema_type);
835 if (!schema_type_id_or.ok()) {
836 if (absl_ports::IsNotFound(schema_type_id_or.status())) {
837 // Ignores the schema type if it is not found in the schema store.
838 continue;
839 }
840 return schema_type_id_or.status();
841 }
842 SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie();
843
844 // Then, calls GetScorablePropertyIndex() here to validate if the property
845 // path is scorable under the schema type, no need to check the returned
846 // index value.
847 libtextclassifier3::StatusOr<std::optional<int>>
848 scorable_property_index_or = schema_store.GetScorablePropertyIndex(
849 schema_type_id, property_path);
850 if (!scorable_property_index_or.ok()) {
851 return scorable_property_index_or.status();
852 }
853 if (!scorable_property_index_or.ValueOrDie().has_value()) {
854 return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
855 "'%s' is not defined as a scorable property under schema type %d",
856 property_path.data(), schema_type_id));
857 }
858 schema_type_ids.insert(schema_type_id);
859 }
860 return schema_type_ids;
861 }
862
863 libtextclassifier3::StatusOr<
864 std::unique_ptr<GetScorablePropertyFunctionScoreExpression>>
Create(std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,const SchemaStore * schema_store,const SchemaTypeAliasMap & schema_type_alias_map,int64_t current_time_ms)865 GetScorablePropertyFunctionScoreExpression::Create(
866 std::vector<std::unique_ptr<ScoreExpression>> args,
867 const DocumentStore* document_store, const SchemaStore* schema_store,
868 const SchemaTypeAliasMap& schema_type_alias_map, int64_t current_time_ms) {
869 ICING_RETURN_IF_ERROR(CheckChildrenNotNull(args));
870
871 if (args.size() != 2 || args[0]->type() != ScoreExpressionType::kString ||
872 args[1]->type() != ScoreExpressionType::kString) {
873 return absl_ports::InvalidArgumentError(absl_ports::StrCat(
874 kFunctionName, " must take exactly two string params"));
875 }
876
877 // Validate schema type.
878 ICING_ASSIGN_OR_RETURN(std::string_view alias_schema_type,
879 args[0]->EvaluateString());
880 ICING_ASSIGN_OR_RETURN(std::string_view property_path,
881 args[1]->EvaluateString());
882 ICING_ASSIGN_OR_RETURN(
883 std::unordered_set<SchemaTypeId> schema_type_ids,
884 GetAndValidateSchemaTypeIds(alias_schema_type, property_path,
885 schema_type_alias_map, *schema_store));
886
887 return std::unique_ptr<GetScorablePropertyFunctionScoreExpression>(
888 new GetScorablePropertyFunctionScoreExpression(
889 document_store, schema_store, current_time_ms,
890 std::move(schema_type_ids), property_path));
891 }
892
893 libtextclassifier3::StatusOr<std::vector<double>>
EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it) const894 GetScorablePropertyFunctionScoreExpression::EvaluateList(
895 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const {
896 SchemaTypeId doc_schema_type_id = GetSchemaTypeId(
897 hit_info.document_id(), document_store_, current_time_ms_);
898 if (schema_type_ids_.find(doc_schema_type_id) == schema_type_ids_.end()) {
899 return std::vector<double>();
900 }
901
902 std::unique_ptr<ScorablePropertySet> scorable_property_set =
903 document_store_.GetScorablePropertySet(hit_info.document_id(),
904 current_time_ms_);
905 // It should never happen.
906 if (scorable_property_set == nullptr) {
907 return absl_ports::InternalError(IcingStringUtil::StringPrintf(
908 "Failed to retrieve ScorablePropertySet for document %d",
909 hit_info.document_id()));
910 }
911
912 const ScorablePropertyProto* scorable_property_proto =
913 scorable_property_set->GetScorablePropertyProto(property_path_);
914 // It should never happen as icing generates a default value for each scorable
915 // property when the document is created.
916 if (scorable_property_proto == nullptr) {
917 return absl_ports::InternalError(IcingStringUtil::StringPrintf(
918 "Failed to retrieve ScorablePropertyProto for document %d, and "
919 "property path %s",
920 hit_info.document_id(), property_path_.c_str()));
921 }
922
923 // Converts ScorablePropertyProto to a vector of doubles.
924 if (scorable_property_proto->int64_values_size() > 0) {
925 return std::vector<double>(scorable_property_proto->int64_values().begin(),
926 scorable_property_proto->int64_values().end());
927 } else if (scorable_property_proto->double_values_size() > 0) {
928 return std::vector<double>(scorable_property_proto->double_values().begin(),
929 scorable_property_proto->double_values().end());
930 } else if (scorable_property_proto->boolean_values_size() > 0) {
931 return std::vector<double>(
932 scorable_property_proto->boolean_values().begin(),
933 scorable_property_proto->boolean_values().end());
934 }
935 return std::vector<double>();
936 }
937
938 } // namespace lib
939 } // namespace icing
940