xref: /aosp_15_r20/external/icing/icing/index/embed/embedding-scorer.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/index/embed/embedding-scorer.h"
16 
17 #include <cmath>
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <type_traits>
22 
23 #include "icing/text_classifier/lib3/utils/base/statusor.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/absl_ports/str_cat.h"
26 #include "icing/index/embed/quantizer.h"
27 #include "icing/proto/search.pb.h"
28 
29 namespace icing {
30 namespace lib {
31 
32 namespace {
33 
// Identity conversion used on the all-float scoring path, where no Quantizer
// is passed. This overload participates in overload resolution only when
// T is exactly float.
template <typename T,
          std::enable_if_t<std::is_same<T, float>::value, int> = 0>
inline float ToFloat(T value) {
  return value;
}
38 
39 template <typename T>
ToFloat(T value,const Quantizer &)40 inline std::enable_if_t<std::is_same<T, float>::value, float> ToFloat(
41     T value, const Quantizer&) {
42   return value;
43 }
44 
45 template <typename T>
ToFloat(T quantized,const Quantizer & quantizer)46 inline std::enable_if_t<std::is_same<T, uint8_t>::value, float> ToFloat(
47     T quantized, const Quantizer& quantizer) {
48   return quantizer.Dequantize(quantized);
49 }
50 
// Returns sum_i ToFloat(v1[i]) * ToFloat(v2[i]) over the first `dimension`
// elements. `args` is forwarded unchanged to every ToFloat call: it is empty
// when both vectors hold floats, or the Quantizer used to dequantize uint8_t
// elements. A non-positive `dimension` yields 0.
template <typename T1, typename T2, typename... Args>
float CalculateDotProduct(int dimension, const T1* v1, const T2* v2,
                          const Args&... args) {
  float sum = 0.0f;
  for (int index = 0; index < dimension; ++index) {
    const float lhs = ToFloat(v1[index], args...);
    const float rhs = ToFloat(v2[index], args...);
    sum += lhs * rhs;
  }
  return sum;
}
60 
// Returns the Euclidean (L2) norm of the first `dimension` elements of `v`,
// computed as the square root of v's dot product with itself. `args` is
// forwarded to the element conversion (empty, or a Quantizer).
template <typename T, typename... Args>
float CalculateNorm2(int dimension, const T* v, const Args&... args) {
  const float squared_norm = CalculateDotProduct(dimension, v, v, args...);
  return std::sqrt(squared_norm);
}
65 
// Returns the cosine similarity of the first `dimension` elements of v1 and
// v2: dot(v1, v2) / (||v1|| * ||v2||). When the product of the norms is 0
// (e.g. either vector is all zeros), returns 0 instead of dividing by zero.
//
// Elements are converted with ToFloat(element, args...), so `args` is either
// empty (both vectors are float) or the Quantizer used to dequantize uint8_t
// elements.
//
// The dot product and both squared norms are accumulated in one pass, so each
// element is read and converted (dequantized) exactly once instead of up to
// three times. Each sum is still accumulated in element-index order.
template <typename T1, typename T2, typename... Args>
float CalculateCosine(int dimension, const T1* v1, const T2* v2,
                      const Args&... args) {
  float dot_product = 0.0f;
  float squared_norm1 = 0.0f;
  float squared_norm2 = 0.0f;
  for (int i = 0; i < dimension; ++i) {
    const float x = ToFloat(v1[i], args...);
    const float y = ToFloat(v2[i], args...);
    dot_product += x * y;
    squared_norm1 += x * x;
    squared_norm2 += y * y;
  }
  // Take each square root separately rather than sqrt(n1 * n2), so the
  // intermediate product of squared norms cannot overflow.
  const float divisor = std::sqrt(squared_norm1) * std::sqrt(squared_norm2);
  if (divisor == 0.0f) {
    return 0.0f;
  }
  return dot_product / divisor;
}
76 
// Returns the Euclidean (L2) distance between the first `dimension` elements
// of v1 and v2: sqrt(sum_i (v1[i] - v2[i])^2). Each element is converted with
// ToFloat(element, args...), where `args` is empty or a Quantizer.
template <typename T1, typename T2, typename... Args>
float CalculateEuclideanDistance(int dimension, const T1* v1, const T2* v2,
                                 const Args&... args) {
  float squared_distance = 0.0f;
  for (int index = 0; index < dimension; ++index) {
    const float delta =
        ToFloat(v1[index], args...) - ToFloat(v2[index], args...);
    squared_distance += delta * delta;
  }
  return std::sqrt(squared_distance);
}
87 
88 }  // namespace
89 
90 libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>>
Create(SearchSpecProto::EmbeddingQueryMetricType::Code metric_type)91 EmbeddingScorer::Create(
92     SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) {
93   switch (metric_type) {
94     case SearchSpecProto::EmbeddingQueryMetricType::COSINE:
95       return std::make_unique<CosineEmbeddingScorer>();
96     case SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT:
97       return std::make_unique<DotProductEmbeddingScorer>();
98     case SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN:
99       return std::make_unique<EuclideanDistanceEmbeddingScorer>();
100     default:
101       return absl_ports::InvalidArgumentError(absl_ports::StrCat(
102           "Invalid EmbeddingQueryMetricType: ", std::to_string(metric_type)));
103   }
104 }
105 
Score(int dimension,const float * v1,const float * v2) const106 float CosineEmbeddingScorer::Score(int dimension, const float* v1,
107                                    const float* v2) const {
108   return CalculateCosine(dimension, v1, v2);
109 }
110 
Score(int dimension,const float * v1,const float * v2) const111 float DotProductEmbeddingScorer::Score(int dimension, const float* v1,
112                                        const float* v2) const {
113   return CalculateDotProduct(dimension, v1, v2);
114 }
115 
Score(int dimension,const float * v1,const float * v2) const116 float EuclideanDistanceEmbeddingScorer::Score(int dimension, const float* v1,
117                                               const float* v2) const {
118   return CalculateEuclideanDistance(dimension, v1, v2);
119 }
120 
Score(int dimension,const float * v1,const uint8_t * v2,const Quantizer & quantizer) const121 float CosineEmbeddingScorer::Score(int dimension, const float* v1,
122                                    const uint8_t* v2,
123                                    const Quantizer& quantizer) const {
124   return CalculateCosine(dimension, v1, v2, quantizer);
125 }
126 
Score(int dimension,const float * v1,const uint8_t * v2,const Quantizer & quantizer) const127 float DotProductEmbeddingScorer::Score(int dimension, const float* v1,
128                                        const uint8_t* v2,
129                                        const Quantizer& quantizer) const {
130   return CalculateDotProduct(dimension, v1, v2, quantizer);
131 }
132 
Score(int dimension,const float * v1,const uint8_t * v2,const Quantizer & quantizer) const133 float EuclideanDistanceEmbeddingScorer::Score(
134     int dimension, const float* v1, const uint8_t* v2,
135     const Quantizer& quantizer) const {
136   return CalculateEuclideanDistance(dimension, v1, v2, quantizer);
137 }
138 
139 }  // namespace lib
140 }  // namespace icing
141