xref: /aosp_15_r20/external/icing/icing/testing/embedding-test-utils.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_TESTING_EMBEDDING_TEST_UTILS_H_
16 #define ICING_TESTING_EMBEDDING_TEST_UTILS_H_
17 
18 #include <cstdint>
19 #include <initializer_list>
20 #include <string>
21 #include <string_view>
22 #include <vector>
23 
24 #include "icing/text_classifier/lib3/utils/base/statusor.h"
25 #include "icing/index/embed/embedding-hit.h"
26 #include "icing/index/embed/embedding-index.h"
27 #include "icing/proto/document.pb.h"
28 
29 namespace icing {
30 namespace lib {
31 
CreateVector(const std::string & model_signature,std::initializer_list<float> values)32 inline PropertyProto::VectorProto CreateVector(
33     const std::string& model_signature, std::initializer_list<float> values) {
34   PropertyProto::VectorProto vector;
35   vector.set_model_signature(model_signature);
36   for (float value : values) {
37     vector.add_values(value);
38   }
39   return vector;
40 }
41 
42 template <typename... V>
CreateVector(const std::string & model_signature,V &&...values)43 inline PropertyProto::VectorProto CreateVector(
44     const std::string& model_signature, V&&... values) {
45   return CreateVector(model_signature, values...);
46 }
47 
// Test helper that reads EmbeddingHit entries out of `embedding_index`.
// Declared here and defined elsewhere; the notes below are based on the
// signature only.
//
// NOTE(review): presumably returns the hits stored for the given
// (dimension, model_signature) pair - confirm against the definition.
//
// Returns a StatusOr, so callers must check for an error status before using
// the vector of hits. The index is taken as a pointer-to-const.
48 libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
49 GetEmbeddingHitsFromIndex(const EmbeddingIndex* embedding_index,
50                           uint32_t dimension, std::string_view model_signature);
51 
// Test helper that returns float data read from `embedding_index` as a single
// std::vector<float>. Declared here and defined elsewhere.
//
// NOTE(review): presumably this is the unquantized embedding data held by the
// index; the ordering/layout of the returned floats is not visible from this
// declaration - confirm against the definition.
//
// The result is returned directly rather than wrapped in a StatusOr, so no
// error can be reported through the return value.
52 std::vector<float> GetRawEmbeddingDataFromIndex(
53     const EmbeddingIndex* embedding_index);
54 
55 // Gets the quantized embedding vector from the index based on the given hit,
56 // and returns the dequantized version of the vector.
//
// Parameters:
//   embedding_index - index to read from, taken as a pointer-to-const.
//   hit             - identifies which stored vector to fetch.
//   dimension       - presumably the number of elements in the stored vector;
//                     TODO confirm against the definition.
//
// Returns a StatusOr: the dequantized floats on success, otherwise an error
// status (the specific error codes are not visible from this declaration).
57 libtextclassifier3::StatusOr<std::vector<float>>
58 GetAndRestoreQuantizedEmbeddingVectorFromIndex(
59     const EmbeddingIndex* embedding_index, const EmbeddingHit& hit,
60     uint32_t dimension);
61 
62 }  // namespace lib
63 }  // namespace icing
64 
65 #endif  // ICING_TESTING_EMBEDDING_TEST_UTILS_H_
66