1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/index/embed/embedding-scorer.h"
16
17 #include <cmath>
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <type_traits>
22
23 #include "icing/text_classifier/lib3/utils/base/statusor.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/absl_ports/str_cat.h"
26 #include "icing/index/embed/quantizer.h"
27 #include "icing/proto/search.pb.h"
28
29 namespace icing {
30 namespace lib {
31
32 namespace {
33
// Identity conversion used when no quantizer is forwarded. This overload only
// participates in overload resolution when T is exactly float.
template <typename T>
inline auto ToFloat(T value)
    -> std::enable_if_t<std::is_same<T, float>::value, float> {
  return value;
}
38
39 template <typename T>
ToFloat(T value,const Quantizer &)40 inline std::enable_if_t<std::is_same<T, float>::value, float> ToFloat(
41 T value, const Quantizer&) {
42 return value;
43 }
44
45 template <typename T>
ToFloat(T quantized,const Quantizer & quantizer)46 inline std::enable_if_t<std::is_same<T, uint8_t>::value, float> ToFloat(
47 T quantized, const Quantizer& quantizer) {
48 return quantizer.Dequantize(quantized);
49 }
50
// Returns the sum of element-wise products of the first `dimension` elements
// of v1 and v2. Each element is converted to float through ToFloat, forwarding
// `args` (either nothing, or a Quantizer for quantized inputs). A non-positive
// `dimension` yields 0.
template <typename T1, typename T2, typename... Args>
float CalculateDotProduct(int dimension, const T1* v1, const T2* v2,
                          const Args&... args) {
  float sum = 0.0f;
  for (int idx = 0; idx < dimension; ++idx) {
    const float lhs = ToFloat(v1[idx], args...);
    const float rhs = ToFloat(v2[idx], args...);
    sum += lhs * rhs;
  }
  return sum;
}
60
// Euclidean (L2) norm of the first `dimension` elements of v: the square root
// of the vector's dot product with itself. `args` is forwarded to the element
// conversion.
template <typename T, typename... Args>
float CalculateNorm2(int dimension, const T* v, const Args&... args) {
  const float squared_norm = CalculateDotProduct(dimension, v, v, args...);
  return std::sqrt(squared_norm);
}
65
// Cosine similarity of the first `dimension` elements of v1 and v2:
//   dot(v1, v2) / (|v1| * |v2|)
// Returns 0 when the product of the norms is 0 (e.g. either vector is all
// zeros, or dimension <= 0), avoiding a division by zero.
//
// `args` is forwarded to ToFloat (nothing for float inputs, a Quantizer for
// quantized inputs).
//
// The dot product and both squared norms are accumulated in a single pass, so
// each element is read and converted (dequantized) exactly once. Each sum is
// accumulated in index order, matching separate per-quantity passes.
template <typename T1, typename T2, typename... Args>
float CalculateCosine(int dimension, const T1* v1, const T2* v2,
                      const Args&... args) {
  float dot_product = 0.0f;
  float squared_norm1 = 0.0f;
  float squared_norm2 = 0.0f;
  for (int i = 0; i < dimension; ++i) {
    const float a = ToFloat(v1[i], args...);
    const float b = ToFloat(v2[i], args...);
    dot_product += a * b;
    squared_norm1 += a * a;
    squared_norm2 += b * b;
  }
  const float divisor = std::sqrt(squared_norm1) * std::sqrt(squared_norm2);
  if (divisor == 0.0f) {
    return 0.0f;
  }
  return dot_product / divisor;
}
76
// Euclidean distance between the first `dimension` elements of v1 and v2:
// sqrt(sum_i (v1[i] - v2[i])^2), with elements converted to float through
// ToFloat using the forwarded `args`. A non-positive `dimension` yields 0.
template <typename T1, typename T2, typename... Args>
float CalculateEuclideanDistance(int dimension, const T1* v1, const T2* v2,
                                 const Args&... args) {
  float squared_distance = 0.0f;
  int i = 0;
  while (i < dimension) {
    const float delta = ToFloat(v1[i], args...) - ToFloat(v2[i], args...);
    squared_distance += delta * delta;
    ++i;
  }
  return std::sqrt(squared_distance);
}
87
88 } // namespace
89
90 libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingScorer>>
Create(SearchSpecProto::EmbeddingQueryMetricType::Code metric_type)91 EmbeddingScorer::Create(
92 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) {
93 switch (metric_type) {
94 case SearchSpecProto::EmbeddingQueryMetricType::COSINE:
95 return std::make_unique<CosineEmbeddingScorer>();
96 case SearchSpecProto::EmbeddingQueryMetricType::DOT_PRODUCT:
97 return std::make_unique<DotProductEmbeddingScorer>();
98 case SearchSpecProto::EmbeddingQueryMetricType::EUCLIDEAN:
99 return std::make_unique<EuclideanDistanceEmbeddingScorer>();
100 default:
101 return absl_ports::InvalidArgumentError(absl_ports::StrCat(
102 "Invalid EmbeddingQueryMetricType: ", std::to_string(metric_type)));
103 }
104 }
105
Score(int dimension,const float * v1,const float * v2) const106 float CosineEmbeddingScorer::Score(int dimension, const float* v1,
107 const float* v2) const {
108 return CalculateCosine(dimension, v1, v2);
109 }
110
Score(int dimension,const float * v1,const float * v2) const111 float DotProductEmbeddingScorer::Score(int dimension, const float* v1,
112 const float* v2) const {
113 return CalculateDotProduct(dimension, v1, v2);
114 }
115
Score(int dimension,const float * v1,const float * v2) const116 float EuclideanDistanceEmbeddingScorer::Score(int dimension, const float* v1,
117 const float* v2) const {
118 return CalculateEuclideanDistance(dimension, v1, v2);
119 }
120
Score(int dimension,const float * v1,const uint8_t * v2,const Quantizer & quantizer) const121 float CosineEmbeddingScorer::Score(int dimension, const float* v1,
122 const uint8_t* v2,
123 const Quantizer& quantizer) const {
124 return CalculateCosine(dimension, v1, v2, quantizer);
125 }
126
Score(int dimension,const float * v1,const uint8_t * v2,const Quantizer & quantizer) const127 float DotProductEmbeddingScorer::Score(int dimension, const float* v1,
128 const uint8_t* v2,
129 const Quantizer& quantizer) const {
130 return CalculateDotProduct(dimension, v1, v2, quantizer);
131 }
132
Score(int dimension,const float * v1,const uint8_t * v2,const Quantizer & quantizer) const133 float EuclideanDistanceEmbeddingScorer::Score(
134 int dimension, const float* v1, const uint8_t* v2,
135 const Quantizer& quantizer) const {
136 return CalculateEuclideanDistance(dimension, v1, v2, quantizer);
137 }
138
139 } // namespace lib
140 } // namespace icing
141