1 // Copyright (C) 2024 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_UTIL_EMBEDDING_UTIL_H_ 16 #define ICING_UTIL_EMBEDDING_UTIL_H_ 17 18 #include <string_view> 19 20 #include "icing/text_classifier/lib3/utils/base/statusor.h" 21 #include "icing/absl_ports/canonical_errors.h" 22 #include "icing/absl_ports/str_cat.h" 23 #include "icing/proto/search.pb.h" 24 25 namespace icing { 26 namespace lib { 27 28 namespace embedding_util { 29 30 inline libtextclassifier3::StatusOr< 31 SearchSpecProto::EmbeddingQueryMetricType::Code> GetEmbeddingQueryMetricTypeFromName(std::string_view metric_name)32GetEmbeddingQueryMetricTypeFromName(std::string_view metric_name) { 33 if (metric_name == "COSINE") { 34 return SearchSpecProto::EmbeddingQueryMetricType::COSINE; 35 } else if (metric_name == "DOT_PRODUCT") { 36 return SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT; 37 } else if (metric_name == "EUCLIDEAN") { 38 return SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN; 39 } 40 return absl_ports::InvalidArgumentError( 41 absl_ports::StrCat("Unknown metric type: ", metric_name)); 42 } 43 44 } // namespace embedding_util 45 46 } // namespace lib 47 } // namespace icing 48 49 #endif // ICING_UTIL_EMBEDDING_UTIL_H_ 50