1 // Copyright (C) 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ 16 #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ 17 18 #include <cstdint> 19 #include <memory> 20 #include <string> 21 #include <string_view> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <utility> 25 #include <vector> 26 27 #include "icing/text_classifier/lib3/utils/base/statusor.h" 28 #include "icing/absl_ports/canonical_errors.h" 29 #include "icing/index/embed/embedding-query-results.h" 30 #include "icing/index/hit/doc-hit-info.h" 31 #include "icing/index/iterator/doc-hit-info-iterator.h" 32 #include "icing/join/join-children-fetcher.h" 33 #include "icing/schema/schema-store.h" 34 #include "icing/scoring/bm25f-calculator.h" 35 #include "icing/scoring/section-weights.h" 36 #include "icing/store/document-filter-data.h" 37 #include "icing/store/document-id.h" 38 #include "icing/store/document-store.h" 39 #include "icing/util/status-macros.h" 40 41 namespace icing { 42 namespace lib { 43 44 enum class ScoreExpressionType { 45 kDouble, 46 kDoubleList, 47 kDocument, // Only "this" is considered as document type. 48 // TODO(b/326656531): Instead of creating a vector index type, consider 49 // changing it to vector type so that the data is the vector directly. 50 kVectorIndex, 51 kString, 52 }; 53 54 // A map from alias schema type to a set of Icing schema types. 55 using SchemaTypeAliasMap = 56 std::unordered_map<std::string, std::unordered_set<std::string>>; 57 58 class ScoreExpression { 59 public: 60 virtual ~ScoreExpression() = default; 61 62 // Evaluate the score expression to double with the current document. 63 // 64 // RETURNS: 65 // - The evaluated result as a double on success. 66 // - INVALID_ARGUMENT if a non-finite value is reached while evaluating the 67 // expression. 68 // - INTERNAL if there are inconsistencies. EvaluateDouble(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)69 virtual libtextclassifier3::StatusOr<double> EvaluateDouble( 70 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { 71 if (type() == ScoreExpressionType::kDouble) { 72 return absl_ports::UnimplementedError( 73 "All ScoreExpressions of type double must provide their own " 74 "implementation of EvaluateDouble!"); 75 } 76 return absl_ports::InternalError( 77 "Runtime type error: the expression should never be evaluated to a " 78 "double. There must be inconsistencies in the static type checking."); 79 } 80 EvaluateList(const DocHitInfo & hit_info,const DocHitInfoIterator * query_it)81 virtual libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 82 const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) const { 83 if (type() == ScoreExpressionType::kDoubleList) { 84 return absl_ports::UnimplementedError( 85 "All ScoreExpressions of type double List must provide their own " 86 "implementation of EvaluateList!"); 87 } 88 return absl_ports::InternalError( 89 "Runtime type error: the expression should never be evaluated to a " 90 "double list. There must be inconsistencies in the static type " 91 "checking."); 92 } 93 EvaluateString()94 virtual libtextclassifier3::StatusOr<std::string_view> EvaluateString() 95 const { 96 if (type() == ScoreExpressionType::kString) { 97 return absl_ports::UnimplementedError( 98 "All ScoreExpressions of type string must provide their own " 99 "implementation of EvaluateString!"); 100 } 101 return absl_ports::InternalError( 102 "Runtime type error: the expression should never be evaluated to a " 103 "string. There must be inconsistencies in the static type checking."); 104 } 105 106 // Indicate the type to which the current expression will be evaluated. 107 virtual ScoreExpressionType type() const = 0; 108 109 // Indicate whether the current expression is a constant. 110 // Returns true if and only if the object is of ConstantScoreExpression or 111 // StringExpression type. is_constant()112 virtual bool is_constant() const { return false; } 113 }; 114 115 class ThisExpression : public ScoreExpression { 116 public: Create()117 static std::unique_ptr<ThisExpression> Create() { 118 return std::unique_ptr<ThisExpression>(new ThisExpression()); 119 } 120 type()121 ScoreExpressionType type() const override { 122 return ScoreExpressionType::kDocument; 123 } 124 125 private: 126 ThisExpression() = default; 127 }; 128 129 class ConstantScoreExpression : public ScoreExpression { 130 public: 131 static std::unique_ptr<ConstantScoreExpression> Create( 132 double c, ScoreExpressionType type = ScoreExpressionType::kDouble) { 133 return std::unique_ptr<ConstantScoreExpression>( 134 new ConstantScoreExpression(c, type)); 135 } 136 EvaluateDouble(const DocHitInfo &,const DocHitInfoIterator *)137 libtextclassifier3::StatusOr<double> EvaluateDouble( 138 const DocHitInfo&, const DocHitInfoIterator*) const override { 139 return c_; 140 } 141 type()142 ScoreExpressionType type() const override { return type_; } 143 is_constant()144 bool is_constant() const override { return true; } 145 146 private: ConstantScoreExpression(double c,ScoreExpressionType type)147 explicit ConstantScoreExpression(double c, ScoreExpressionType type) 148 : c_(c), type_(type) {} 149 150 double c_; 151 ScoreExpressionType type_; 152 }; 153 154 class StringExpression : public ScoreExpression { 155 public: Create(std::string str)156 static std::unique_ptr<StringExpression> Create(std::string str) { 157 return std::unique_ptr<StringExpression>( 158 new StringExpression(std::move(str))); 159 } 160 EvaluateString()161 libtextclassifier3::StatusOr<std::string_view> EvaluateString() 162 const override { 163 return str_; 164 } 165 type()166 ScoreExpressionType type() const override { 167 return ScoreExpressionType::kString; 168 } 169 is_constant()170 bool is_constant() const override { return true; } 171 172 private: StringExpression(std::string str)173 explicit StringExpression(std::string str) : str_(std::move(str)) {} 174 std::string str_; 175 }; 176 177 class OperatorScoreExpression : public ScoreExpression { 178 public: 179 enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv }; 180 181 // RETURNS: 182 // - An OperatorScoreExpression instance on success if not simplifiable. 183 // - A ConstantScoreExpression instance on success if simplifiable. 184 // - FAILED_PRECONDITION on any null pointer in children. 185 // - INVALID_ARGUMENT on type errors. 186 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 187 OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children); 188 189 libtextclassifier3::StatusOr<double> EvaluateDouble( 190 const DocHitInfo& hit_info, 191 const DocHitInfoIterator* query_it) const override; 192 type()193 ScoreExpressionType type() const override { 194 return ScoreExpressionType::kDouble; 195 } 196 197 private: OperatorScoreExpression(OperatorType op,std::vector<std::unique_ptr<ScoreExpression>> children)198 explicit OperatorScoreExpression( 199 OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) 200 : op_(op), children_(std::move(children)) {} 201 202 OperatorType op_; 203 std::vector<std::unique_ptr<ScoreExpression>> children_; 204 }; 205 206 class MathFunctionScoreExpression : public ScoreExpression { 207 public: 208 enum class FunctionType { 209 kLog, 210 kPow, 211 kMax, 212 kMin, 213 kLen, 214 kSum, 215 kAvg, 216 kSqrt, 217 kAbs, 218 kSin, 219 kCos, 220 kTan, 221 kMaxOrDefault, 222 kMinOrDefault, 223 }; 224 225 static const std::unordered_map<std::string, FunctionType> kFunctionNames; 226 227 static const std::unordered_set<FunctionType> kVariableArgumentsFunctions; 228 229 static const std::unordered_set<FunctionType> kListArgumentFunctions; 230 231 // RETURNS: 232 // - A MathFunctionScoreExpression instance on success if not simplifiable. 233 // - A ConstantScoreExpression instance on success if simplifiable. 234 // - FAILED_PRECONDITION on any null pointer in args. 235 // - INVALID_ARGUMENT on type errors. 236 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 237 FunctionType function_type, 238 std::vector<std::unique_ptr<ScoreExpression>> args); 239 240 libtextclassifier3::StatusOr<double> EvaluateDouble( 241 const DocHitInfo& hit_info, 242 const DocHitInfoIterator* query_it) const override; 243 type()244 ScoreExpressionType type() const override { 245 return ScoreExpressionType::kDouble; 246 } 247 248 private: MathFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)249 explicit MathFunctionScoreExpression( 250 FunctionType function_type, 251 std::vector<std::unique_ptr<ScoreExpression>> args) 252 : function_type_(function_type), args_(std::move(args)) {} 253 254 FunctionType function_type_; 255 std::vector<std::unique_ptr<ScoreExpression>> args_; 256 }; 257 258 class ListOperationFunctionScoreExpression : public ScoreExpression { 259 public: 260 enum class FunctionType { kFilterByRange }; 261 262 static const std::unordered_map<std::string, FunctionType> kFunctionNames; 263 264 // RETURNS: 265 // - A ListOperationFunctionScoreExpression instance on success. 266 // - FAILED_PRECONDITION on any null pointer in args. 267 // - INVALID_ARGUMENT on type errors. 268 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 269 FunctionType function_type, 270 std::vector<std::unique_ptr<ScoreExpression>> args); 271 272 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 273 const DocHitInfo& hit_info, 274 const DocHitInfoIterator* query_it) const override; 275 type()276 ScoreExpressionType type() const override { 277 return ScoreExpressionType::kDoubleList; 278 } 279 280 private: ListOperationFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args)281 explicit ListOperationFunctionScoreExpression( 282 FunctionType function_type, 283 std::vector<std::unique_ptr<ScoreExpression>> args) 284 : function_type_(function_type), args_(std::move(args)) {} 285 286 FunctionType function_type_; 287 std::vector<std::unique_ptr<ScoreExpression>> args_; 288 }; 289 290 class DocumentFunctionScoreExpression : public ScoreExpression { 291 public: 292 enum class FunctionType { 293 kDocumentScore, 294 kCreationTimestamp, 295 kUsageCount, 296 kUsageLastUsedTimestamp, 297 }; 298 299 static const std::unordered_map<std::string, FunctionType> kFunctionNames; 300 301 // RETURNS: 302 // - A DocumentFunctionScoreExpression instance on success. 303 // - FAILED_PRECONDITION on any null pointer in args. 304 // - INVALID_ARGUMENT on type errors. 305 static libtextclassifier3::StatusOr< 306 std::unique_ptr<DocumentFunctionScoreExpression>> 307 Create(FunctionType function_type, 308 std::vector<std::unique_ptr<ScoreExpression>> args, 309 const DocumentStore* document_store, double default_score, 310 int64_t current_time_ms); 311 312 libtextclassifier3::StatusOr<double> EvaluateDouble( 313 const DocHitInfo& hit_info, 314 const DocHitInfoIterator* query_it) const override; 315 type()316 ScoreExpressionType type() const override { 317 return ScoreExpressionType::kDouble; 318 } 319 320 private: DocumentFunctionScoreExpression(FunctionType function_type,std::vector<std::unique_ptr<ScoreExpression>> args,const DocumentStore * document_store,double default_score,int64_t current_time_ms)321 explicit DocumentFunctionScoreExpression( 322 FunctionType function_type, 323 std::vector<std::unique_ptr<ScoreExpression>> args, 324 const DocumentStore* document_store, double default_score, 325 int64_t current_time_ms) 326 : args_(std::move(args)), 327 document_store_(*document_store), 328 default_score_(default_score), 329 function_type_(function_type), 330 current_time_ms_(current_time_ms) {} 331 332 std::vector<std::unique_ptr<ScoreExpression>> args_; 333 const DocumentStore& document_store_; 334 double default_score_; 335 FunctionType function_type_; 336 int64_t current_time_ms_; 337 }; 338 339 class RelevanceScoreFunctionScoreExpression : public ScoreExpression { 340 public: 341 static constexpr std::string_view kFunctionName = "relevanceScore"; 342 343 // RETURNS: 344 // - A RelevanceScoreFunctionScoreExpression instance on success. 345 // - FAILED_PRECONDITION on any null pointer in args. 346 // - INVALID_ARGUMENT on type errors. 347 static libtextclassifier3::StatusOr< 348 std::unique_ptr<RelevanceScoreFunctionScoreExpression>> 349 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 350 Bm25fCalculator* bm25f_calculator, double default_score); 351 352 libtextclassifier3::StatusOr<double> EvaluateDouble( 353 const DocHitInfo& hit_info, 354 const DocHitInfoIterator* query_it) const override; 355 type()356 ScoreExpressionType type() const override { 357 return ScoreExpressionType::kDouble; 358 } 359 360 private: RelevanceScoreFunctionScoreExpression(Bm25fCalculator * bm25f_calculator,double default_score)361 explicit RelevanceScoreFunctionScoreExpression( 362 Bm25fCalculator* bm25f_calculator, double default_score) 363 : bm25f_calculator_(*bm25f_calculator), default_score_(default_score) {} 364 365 Bm25fCalculator& bm25f_calculator_; 366 double default_score_; 367 }; 368 369 class ChildrenRankingSignalsFunctionScoreExpression : public ScoreExpression { 370 public: 371 static constexpr std::string_view kFunctionName = "childrenRankingSignals"; 372 373 // RETURNS: 374 // - A ChildrenRankingSignalsFunctionScoreExpression instance on success. 375 // - FAILED_PRECONDITION on any null pointer in children. 376 // - INVALID_ARGUMENT on type errors. 377 static libtextclassifier3::StatusOr< 378 std::unique_ptr<ChildrenRankingSignalsFunctionScoreExpression>> 379 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 380 const DocumentStore& document_store, 381 const JoinChildrenFetcher* join_children_fetcher, 382 int64_t current_time_ms); 383 384 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 385 const DocHitInfo& hit_info, 386 const DocHitInfoIterator* query_it) const override; 387 type()388 ScoreExpressionType type() const override { 389 return ScoreExpressionType::kDoubleList; 390 } 391 392 private: ChildrenRankingSignalsFunctionScoreExpression(const DocumentStore & document_store,const JoinChildrenFetcher & join_children_fetcher,int64_t current_time_ms)393 explicit ChildrenRankingSignalsFunctionScoreExpression( 394 const DocumentStore& document_store, 395 const JoinChildrenFetcher& join_children_fetcher, int64_t current_time_ms) 396 : document_store_(document_store), 397 join_children_fetcher_(join_children_fetcher), 398 current_time_ms_(current_time_ms) {} 399 400 const DocumentStore& document_store_; // Does not own. 401 const JoinChildrenFetcher& join_children_fetcher_; // Does not own. 402 int64_t current_time_ms_; 403 }; 404 405 class PropertyWeightsFunctionScoreExpression : public ScoreExpression { 406 public: 407 static constexpr std::string_view kFunctionName = "propertyWeights"; 408 409 // RETURNS: 410 // - A PropertyWeightsFunctionScoreExpression instance on success. 411 // - FAILED_PRECONDITION on any null pointer in children. 412 // - INVALID_ARGUMENT on type errors. 413 static libtextclassifier3::StatusOr< 414 std::unique_ptr<PropertyWeightsFunctionScoreExpression>> 415 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 416 const DocumentStore* document_store, 417 const SectionWeights* section_weights, int64_t current_time_ms); 418 419 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 420 const DocHitInfo& hit_info, const DocHitInfoIterator*) const override; 421 type()422 ScoreExpressionType type() const override { 423 return ScoreExpressionType::kDoubleList; 424 } 425 426 private: PropertyWeightsFunctionScoreExpression(const DocumentStore * document_store,const SectionWeights * section_weights,int64_t current_time_ms)427 explicit PropertyWeightsFunctionScoreExpression( 428 const DocumentStore* document_store, 429 const SectionWeights* section_weights, int64_t current_time_ms) 430 : document_store_(*document_store), 431 section_weights_(*section_weights), 432 current_time_ms_(current_time_ms) {} 433 const DocumentStore& document_store_; 434 const SectionWeights& section_weights_; 435 int64_t current_time_ms_; 436 }; 437 438 class GetEmbeddingParameterFunctionScoreExpression : public ScoreExpression { 439 public: 440 static constexpr std::string_view kFunctionName = "getEmbeddingParameter"; 441 442 // RETURNS: 443 // - A GetEmbeddingParameterFunctionScoreExpression instance on success if 444 // not simplifiable. 445 // - A ConstantScoreExpression instance on success if simplifiable. 446 // - FAILED_PRECONDITION on any null pointer in children. 447 // - INVALID_ARGUMENT on type errors. 448 static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( 449 std::vector<std::unique_ptr<ScoreExpression>> args); 450 451 libtextclassifier3::StatusOr<double> EvaluateDouble( 452 const DocHitInfo& hit_info, 453 const DocHitInfoIterator* query_it) const override; 454 type()455 ScoreExpressionType type() const override { 456 return ScoreExpressionType::kVectorIndex; 457 } 458 459 private: GetEmbeddingParameterFunctionScoreExpression(std::unique_ptr<ScoreExpression> arg)460 explicit GetEmbeddingParameterFunctionScoreExpression( 461 std::unique_ptr<ScoreExpression> arg) 462 : arg_(std::move(arg)) {} 463 std::unique_ptr<ScoreExpression> arg_; 464 }; 465 466 class MatchedSemanticScoresFunctionScoreExpression : public ScoreExpression { 467 public: 468 static constexpr std::string_view kFunctionName = "matchedSemanticScores"; 469 470 // RETURNS: 471 // - A MatchedSemanticScoresFunctionScoreExpression instance on success. 472 // - FAILED_PRECONDITION on any null pointer in children. 473 // - INVALID_ARGUMENT on type errors. 474 static libtextclassifier3::StatusOr< 475 std::unique_ptr<MatchedSemanticScoresFunctionScoreExpression>> 476 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 477 SearchSpecProto::EmbeddingQueryMetricType::Code default_metric_type, 478 const EmbeddingQueryResults* embedding_query_results); 479 480 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 481 const DocHitInfo& hit_info, 482 const DocHitInfoIterator* query_it) const override; 483 type()484 ScoreExpressionType type() const override { 485 return ScoreExpressionType::kDoubleList; 486 } 487 488 private: MatchedSemanticScoresFunctionScoreExpression(std::vector<std::unique_ptr<ScoreExpression>> args,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,const EmbeddingQueryResults & embedding_query_results)489 explicit MatchedSemanticScoresFunctionScoreExpression( 490 std::vector<std::unique_ptr<ScoreExpression>> args, 491 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, 492 const EmbeddingQueryResults& embedding_query_results) 493 : args_(std::move(args)), 494 metric_type_(metric_type), 495 embedding_query_results_(embedding_query_results) {} 496 497 std::vector<std::unique_ptr<ScoreExpression>> args_; 498 const SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; 499 const EmbeddingQueryResults& embedding_query_results_; 500 }; 501 502 class GetScorablePropertyFunctionScoreExpression : public ScoreExpression { 503 public: 504 static constexpr std::string_view kFunctionName = "getScorableProperty"; 505 506 // Returns: 507 // - FAILED_PRECONDITION on any null pointer in children. 508 // - INVALID_ARGUMENT on 509 // - |args| type errors. 510 // - alias_schema_type in the scoring expression does not match to any 511 // schema type in the |schema_type_alias_map|. 512 // - any matched schema type not having the specified property_path as a 513 // scorable property. 514 static libtextclassifier3::StatusOr< 515 std::unique_ptr<GetScorablePropertyFunctionScoreExpression>> 516 Create(std::vector<std::unique_ptr<ScoreExpression>> args, 517 const DocumentStore* document_store, const SchemaStore* schema_store, 518 const SchemaTypeAliasMap& schema_type_alias_map, 519 int64_t current_time_ms); 520 type()521 ScoreExpressionType type() const override { 522 return ScoreExpressionType::kDoubleList; 523 } 524 525 libtextclassifier3::StatusOr<std::vector<double>> EvaluateList( 526 const DocHitInfo& hit_info, 527 const DocHitInfoIterator* query_it) const override; 528 529 private: 530 explicit GetScorablePropertyFunctionScoreExpression( 531 const DocumentStore* document_store, const SchemaStore* schema_store, 532 int64_t current_time_ms, 533 std::unordered_set<SchemaTypeId>&& schema_type_ids, 534 std::string_view property_path); 535 536 // Returns a set of schema type ids that are matched to the 537 // |alias_schema_type| in the scoring expression, based on the 538 // |schema_type_alias_map|. 539 // 540 // For each of the schema type in the returned set, this function also 541 // validates that: 542 // - The schema type is valid in the schema store. 543 // - The |property_path| is defined as scorable under the schema type. 544 static libtextclassifier3::StatusOr<std::unordered_set<SchemaTypeId>> 545 GetAndValidateSchemaTypeIds(std::string_view alias_schema_type, 546 std::string_view property_path, 547 const SchemaTypeAliasMap& schema_type_alias_map, 548 const SchemaStore& schema_store); 549 550 const DocumentStore& document_store_; 551 const SchemaStore& schema_store_; 552 int64_t current_time_ms_; 553 // A doc hit is evaluated by this function only if its schema type id is in 554 // this set. 555 std::unordered_set<SchemaTypeId> schema_type_ids_; 556 std::string property_path_; 557 }; 558 559 } // namespace lib 560 } // namespace icing 561 562 #endif // ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ 563