xref: /aosp_15_r20/external/icing/icing/index/embedding-indexing-handler_test.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/index/embedding-indexing-handler.h"
16 
17 #include <initializer_list>
18 #include <memory>
19 #include <string>
20 #include <string_view>
21 #include <utility>
22 
23 #include "icing/text_classifier/lib3/utils/base/status.h"
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "icing/document-builder.h"
27 #include "icing/feature-flags.h"
28 #include "icing/file/filesystem.h"
29 #include "icing/file/portable-file-backed-proto-log.h"
30 #include "icing/index/embed/embedding-hit.h"
31 #include "icing/index/embed/embedding-index.h"
32 #include "icing/index/embed/quantizer.h"
33 #include "icing/index/hit/hit.h"
34 #include "icing/portable/platform.h"
35 #include "icing/proto/document_wrapper.pb.h"
36 #include "icing/proto/schema.pb.h"
37 #include "icing/schema-builder.h"
38 #include "icing/schema/schema-store.h"
39 #include "icing/schema/section.h"
40 #include "icing/store/document-id.h"
41 #include "icing/store/document-store.h"
42 #include "icing/testing/common-matchers.h"
43 #include "icing/testing/embedding-test-utils.h"
44 #include "icing/testing/fake-clock.h"
45 #include "icing/testing/test-data.h"
46 #include "icing/testing/test-feature-flags.h"
47 #include "icing/testing/tmp-directory.h"
48 #include "icing/tokenization/language-segmenter-factory.h"
49 #include "icing/tokenization/language-segmenter.h"
50 #include "icing/util/icu-data-file-helper.h"
51 #include "icing/util/tokenized-document.h"
52 #include "unicode/uloc.h"
53 
54 namespace icing {
55 namespace lib {
56 
57 namespace {
58 
59 using ::testing::ElementsAre;
60 using ::testing::Eq;
61 using ::testing::FloatNear;
62 using ::testing::IsEmpty;
63 using ::testing::IsTrue;
64 using ::testing::Pointwise;
65 
66 // Indexable properties (section) and section id. Section id is determined by
67 // the lexicographical order of indexable property paths.
68 // Schema type with indexable properties: FakeType
69 // Section id = 0: "body"
70 // Section id = 1: "bodyEmbedding"
71 // Section id = 2: "quantizedEmbedding"
72 // Section id = 3: "title"
73 // Section id = 4: "titleEmbedding"
74 static constexpr std::string_view kFakeType = "FakeType";
75 static constexpr std::string_view kPropertyBody = "body";
76 static constexpr std::string_view kPropertyBodyEmbedding = "bodyEmbedding";
77 static constexpr std::string_view kPropertyQuantizedEmbedding =
78     "quantizedEmbedding";
79 static constexpr std::string_view kPropertyTitle = "title";
80 static constexpr std::string_view kPropertyTitleEmbedding = "titleEmbedding";
81 static constexpr std::string_view kPropertyNonIndexableEmbedding =
82     "nonIndexableEmbedding";
83 
84 static constexpr SectionId kSectionIdBodyEmbedding = 1;
85 static constexpr SectionId kSectionIdQuantizedEmbedding = 2;
86 static constexpr SectionId kSectionIdTitleEmbedding = 4;
87 
88 // Schema type with nested indexable properties: FakeCollectionType
89 // Section id = 0: "collection.body"
90 // Section id = 1: "collection.bodyEmbedding"
91 // Section id = 2: "collection.quantizedEmbedding"
92 // Section id = 3: "collection.title"
93 // Section id = 4: "collection.titleEmbedding"
94 // Section id = 5: "fullDocEmbedding"
95 static constexpr std::string_view kFakeCollectionType = "FakeCollectionType";
96 static constexpr std::string_view kPropertyCollection = "collection";
97 static constexpr std::string_view kPropertyFullDocEmbedding =
98     "fullDocEmbedding";
99 
100 static constexpr SectionId kSectionIdNestedBodyEmbedding = 1;
101 static constexpr SectionId kSectionIdNestedQuantizedEmbedding = 2;
102 static constexpr SectionId kSectionIdNestedTitleEmbedding = 4;
103 static constexpr SectionId kSectionIdFullDocEmbedding = 5;
104 
105 constexpr float kEpsQuantized = 0.01f;
106 
107 class EmbeddingIndexingHandlerTest : public ::testing::Test {
108  protected:
SetUp()109   void SetUp() override {
110     feature_flags_ = std::make_unique<FeatureFlags>(GetTestFeatureFlags());
111     if (!IsCfStringTokenization() && !IsReverseJniTokenization()) {
112       ICING_ASSERT_OK(
113           // File generated via icu_data_file rule in //icing/BUILD.
114           icu_data_file_helper::SetUpIcuDataFile(
115               GetTestFilePath("icing/icu.dat")));
116     }
117 
118     base_dir_ = GetTestTempDir() + "/icing_test";
119     ASSERT_THAT(filesystem_.CreateDirectoryRecursively(base_dir_.c_str()),
120                 IsTrue());
121 
122     embedding_index_working_path_ = base_dir_ + "/embedding_index";
123     schema_store_dir_ = base_dir_ + "/schema_store";
124     document_store_dir_ = base_dir_ + "/document_store";
125 
126     language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US);
127     ICING_ASSERT_OK_AND_ASSIGN(
128         lang_segmenter_,
129         language_segmenter_factory::Create(std::move(segmenter_options)));
130 
131     ASSERT_THAT(
132         filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()),
133         IsTrue());
134     ICING_ASSERT_OK_AND_ASSIGN(
135         schema_store_, SchemaStore::Create(&filesystem_, schema_store_dir_,
136                                            &fake_clock_, feature_flags_.get()));
137     SchemaProto schema =
138         SchemaBuilder()
139             .AddType(
140                 SchemaTypeConfigBuilder()
141                     .SetType(kFakeType)
142                     .AddProperty(PropertyConfigBuilder()
143                                      .SetName(kPropertyTitle)
144                                      .SetDataTypeString(TERM_MATCH_EXACT,
145                                                         TOKENIZER_PLAIN)
146                                      .SetCardinality(CARDINALITY_OPTIONAL))
147                     .AddProperty(PropertyConfigBuilder()
148                                      .SetName(kPropertyBody)
149                                      .SetDataTypeString(TERM_MATCH_EXACT,
150                                                         TOKENIZER_PLAIN)
151                                      .SetCardinality(CARDINALITY_REPEATED))
152                     .AddProperty(
153                         PropertyConfigBuilder()
154                             .SetName(kPropertyTitleEmbedding)
155                             .SetDataTypeVector(
156                                 EmbeddingIndexingConfig::EmbeddingIndexingType::
157                                     LINEAR_SEARCH)
158                             .SetCardinality(CARDINALITY_OPTIONAL))
159                     .AddProperty(
160                         PropertyConfigBuilder()
161                             .SetName(kPropertyBodyEmbedding)
162                             .SetDataTypeVector(
163                                 EmbeddingIndexingConfig::EmbeddingIndexingType::
164                                     LINEAR_SEARCH)
165                             .SetCardinality(CARDINALITY_REPEATED))
166                     .AddProperty(
167                         PropertyConfigBuilder()
168                             .SetName(kPropertyQuantizedEmbedding)
169                             .SetDataTypeVector(
170                                 EmbeddingIndexingConfig::EmbeddingIndexingType::
171                                     LINEAR_SEARCH,
172                                 QUANTIZATION_TYPE_QUANTIZE_8_BIT)
173                             .SetCardinality(CARDINALITY_REPEATED))
174                     .AddProperty(PropertyConfigBuilder()
175                                      .SetName(kPropertyNonIndexableEmbedding)
176                                      .SetDataType(TYPE_VECTOR)
177                                      .SetCardinality(CARDINALITY_REPEATED)))
178             .AddType(SchemaTypeConfigBuilder()
179                          .SetType(kFakeCollectionType)
180                          .AddProperty(PropertyConfigBuilder()
181                                           .SetName(kPropertyCollection)
182                                           .SetDataTypeDocument(
183                                               kFakeType,
184                                               /*index_nested_properties=*/true)
185                                           .SetCardinality(CARDINALITY_REPEATED))
186                          .AddProperty(
187                              PropertyConfigBuilder()
188                                  .SetName(kPropertyFullDocEmbedding)
189                                  .SetDataTypeVector(
190                                      EmbeddingIndexingConfig::
191                                          EmbeddingIndexingType::LINEAR_SEARCH)
192                                  .SetCardinality(CARDINALITY_OPTIONAL)))
193             .Build();
194     ICING_ASSERT_OK(schema_store_->SetSchema(
195         schema, /*ignore_errors_and_delete_documents=*/false,
196         /*allow_circular_schema_definitions=*/false));
197 
198     ASSERT_TRUE(
199         filesystem_.CreateDirectoryRecursively(document_store_dir_.c_str()));
200     ICING_ASSERT_OK_AND_ASSIGN(
201         DocumentStore::CreateResult doc_store_create_result,
202         DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_,
203                               schema_store_.get(), feature_flags_.get(),
204                               /*force_recovery_and_revalidate_documents=*/false,
205                               /*pre_mapping_fbv=*/false,
206                               /*use_persistent_hash_map=*/true,
207                               PortableFileBackedProtoLog<
208                                   DocumentWrapper>::kDefaultCompressionLevel,
209                               /*initialize_stats=*/nullptr));
210     document_store_ = std::move(doc_store_create_result.document_store);
211 
212     ICING_ASSERT_OK_AND_ASSIGN(
213         embedding_index_,
214         EmbeddingIndex::Create(&filesystem_, embedding_index_working_path_,
215                                &fake_clock_, feature_flags_.get()));
216   }
217 
TearDown()218   void TearDown() override {
219     document_store_.reset();
220     schema_store_.reset();
221     lang_segmenter_.reset();
222     embedding_index_.reset();
223 
224     filesystem_.DeleteDirectoryRecursively(base_dir_.c_str());
225   }
226 
227   std::unique_ptr<FeatureFlags> feature_flags_;
228   Filesystem filesystem_;
229   FakeClock fake_clock_;
230   std::string base_dir_;
231   std::string embedding_index_working_path_;
232   std::string schema_store_dir_;
233   std::string document_store_dir_;
234 
235   std::unique_ptr<EmbeddingIndex> embedding_index_;
236   std::unique_ptr<LanguageSegmenter> lang_segmenter_;
237   std::unique_ptr<SchemaStore> schema_store_;
238   std::unique_ptr<DocumentStore> document_store_;
239 };
240 
241 }  // namespace
242 
TEST_F(EmbeddingIndexingHandlerTest, CreationWithNullPointerShouldFail) {
  // A missing clock is a precondition failure.
  auto missing_clock_result = EmbeddingIndexingHandler::Create(
      /*clock=*/nullptr, embedding_index_.get(),
      /*enable_embedding_index=*/true);
  EXPECT_THAT(missing_clock_result,
              StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));

  // So is a missing embedding index.
  auto missing_index_result = EmbeddingIndexingHandler::Create(
      &fake_clock_, /*embedding_index=*/nullptr,
      /*enable_embedding_index=*/true);
  EXPECT_THAT(missing_index_result,
              StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
254 
TEST_F(EmbeddingIndexingHandlerTest,HandleEmbeddingSection)255 TEST_F(EmbeddingIndexingHandlerTest, HandleEmbeddingSection) {
256   DocumentProto document =
257       DocumentBuilder()
258           .SetKey("icing", "fake_type/1")
259           .SetSchema(std::string(kFakeType))
260           .AddStringProperty(std::string(kPropertyTitle), "title")
261           .AddVectorProperty(std::string(kPropertyTitleEmbedding),
262                              CreateVector("model", {0.1, 0.2, 0.3}))
263           .AddStringProperty(std::string(kPropertyBody), "body")
264           .AddVectorProperty(std::string(kPropertyBodyEmbedding),
265                              CreateVector("model", {0.4, 0.5, 0.6}),
266                              CreateVector("model", {0.7, 0.8, 0.9}))
267           .AddVectorProperty(std::string(kPropertyQuantizedEmbedding),
268                              CreateVector("model", {0.1, 0.2, 0.3}),
269                              CreateVector("model", {0.4, 0.5, 0.6}))
270           .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
271                              CreateVector("model", {1.1, 1.2, 1.3}))
272           .Build();
273   ICING_ASSERT_OK_AND_ASSIGN(
274       TokenizedDocument tokenized_document,
275       TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
276                                 std::move(document)));
277   ICING_ASSERT_OK_AND_ASSIGN(
278       DocumentStore::PutResult put_result,
279       document_store_->Put(tokenized_document.document()));
280   DocumentId document_id = put_result.new_document_id;
281 
282   ASSERT_THAT(embedding_index_->last_added_document_id(),
283               Eq(kInvalidDocumentId));
284   // Handle document.
285   ICING_ASSERT_OK_AND_ASSIGN(
286       std::unique_ptr<EmbeddingIndexingHandler> handler,
287       EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get(),
288                                        /*enable_embedding_index=*/true));
289   EXPECT_THAT(handler->Handle(
290                   tokenized_document, document_id, put_result.old_document_id,
291                   /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
292               IsOk());
293 
294   // Check index
295   EmbeddingHit hit1(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
296                     /*location=*/0);
297   EmbeddingHit hit2(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
298                     /*location=*/3);
299   EmbeddingHit hit3(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
300                     /*location=*/6);
301   // Quantized embeddings are stored in a different location from unquantized
302   // embeddings, so the location starts from 0 again.
303   EmbeddingHit quantized_hit1(
304       BasicHit(kSectionIdQuantizedEmbedding, /*document_id=*/0),
305       /*location=*/0);
306   EmbeddingHit quantized_hit2(
307       BasicHit(kSectionIdQuantizedEmbedding, /*document_id=*/0),
308       /*location=*/3 + sizeof(Quantizer));
309   EXPECT_THAT(GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
310                                         /*model_signature=*/"model"),
311               IsOkAndHolds(ElementsAre(hit1, hit2, quantized_hit1,
312                                        quantized_hit2, hit3)));
313   // Check unquantized embedding data
314   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()),
315               ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
316   // Check quantized embedding data
317   EXPECT_THAT(embedding_index_->GetTotalQuantizedVectorSize(),
318               Eq(6 + 2 * sizeof(Quantizer)));
319   EXPECT_THAT(
320       GetAndRestoreQuantizedEmbeddingVectorFromIndex(embedding_index_.get(),
321                                                      quantized_hit1,
322                                                      /*dimension=*/3),
323       IsOkAndHolds(Pointwise(FloatNear(kEpsQuantized), {0.1, 0.2, 0.3})));
324   EXPECT_THAT(
325       GetAndRestoreQuantizedEmbeddingVectorFromIndex(embedding_index_.get(),
326                                                      quantized_hit2,
327                                                      /*dimension=*/3),
328       IsOkAndHolds(Pointwise(FloatNear(kEpsQuantized), {0.4, 0.5, 0.6})));
329 
330   EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
331 }
332 
// With enable_embedding_index=false the handler must not index any vector
// section, yet it still advances last_added_document_id.
TEST_F(EmbeddingIndexingHandlerTest, EmbeddingShouldNotBeIndexedIfDisabled) {
  // The document carries every kind of vector property (unquantized, 8-bit
  // quantized and non-indexable) so that the disabled handler is checked
  // against all of them, not only the unquantized ones.
  DocumentProto document =
      DocumentBuilder()
          .SetKey("icing", "fake_type/1")
          .SetSchema(std::string(kFakeType))
          .AddStringProperty(std::string(kPropertyTitle), "title")
          .AddVectorProperty(std::string(kPropertyTitleEmbedding),
                             CreateVector("model", {0.1, 0.2, 0.3}))
          .AddStringProperty(std::string(kPropertyBody), "body")
          .AddVectorProperty(std::string(kPropertyBodyEmbedding),
                             CreateVector("model", {0.4, 0.5, 0.6}),
                             CreateVector("model", {0.7, 0.8, 0.9}))
          .AddVectorProperty(std::string(kPropertyQuantizedEmbedding),
                             CreateVector("model", {0.1, 0.2, 0.3}),
                             CreateVector("model", {0.4, 0.5, 0.6}))
          .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
                             CreateVector("model", {1.1, 1.2, 1.3}))
          .Build();
  ICING_ASSERT_OK_AND_ASSIGN(
      TokenizedDocument tokenized_document,
      TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
                                std::move(document)));
  ICING_ASSERT_OK_AND_ASSIGN(
      DocumentStore::PutResult put_result,
      document_store_->Put(tokenized_document.document()));
  DocumentId document_id = put_result.new_document_id;

  ASSERT_THAT(embedding_index_->last_added_document_id(),
              Eq(kInvalidDocumentId));
  // If enable_embedding_index is false, the handler should not index any
  // embeddings.
  ICING_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<EmbeddingIndexingHandler> handler,
      EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get(),
                                       /*enable_embedding_index=*/false));
  EXPECT_THAT(handler->Handle(
                  tokenized_document, document_id, put_result.old_document_id,
                  /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
              IsOk());

  // Check that the embedding index is empty: no hits (unquantized or
  // quantized) and no raw embedding data were written.
  EXPECT_THAT(GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
                                        /*model_signature=*/"model"),
              IsOkAndHolds(IsEmpty()));
  EXPECT_TRUE(embedding_index_->is_empty());
  EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()), IsEmpty());
  // last_added_document_id still advances even though nothing was indexed.
  EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
}
377 
TEST_F(EmbeddingIndexingHandlerTest,HandleNestedEmbeddingSection)378 TEST_F(EmbeddingIndexingHandlerTest, HandleNestedEmbeddingSection) {
379   DocumentProto document =
380       DocumentBuilder()
381           .SetKey("icing", "fake_collection_type/1")
382           .SetSchema(std::string(kFakeCollectionType))
383           .AddDocumentProperty(
384               std::string(kPropertyCollection),
385               DocumentBuilder()
386                   .SetKey("icing", "nested_fake_type/1")
387                   .SetSchema(std::string(kFakeType))
388                   .AddStringProperty(std::string(kPropertyTitle), "title")
389                   .AddVectorProperty(std::string(kPropertyTitleEmbedding),
390                                      CreateVector("model", {0.1, 0.2, 0.3}))
391                   .AddStringProperty(std::string(kPropertyBody), "body")
392                   .AddVectorProperty(std::string(kPropertyBodyEmbedding),
393                                      CreateVector("model", {0.4, 0.5, 0.6}),
394                                      CreateVector("model", {0.7, 0.8, 0.9}))
395                   .AddVectorProperty(std::string(kPropertyQuantizedEmbedding),
396                                      CreateVector("model", {0.1, 0.2, 0.3}),
397                                      CreateVector("model", {0.4, 0.5, 0.6}))
398                   .AddVectorProperty(
399                       std::string(kPropertyNonIndexableEmbedding),
400                       CreateVector("model", {1.1, 1.2, 1.3}))
401                   .Build())
402           .AddVectorProperty(std::string(kPropertyFullDocEmbedding),
403                              CreateVector("model", {2.1, 2.2, 2.3}))
404           .Build();
405   ICING_ASSERT_OK_AND_ASSIGN(
406       TokenizedDocument tokenized_document,
407       TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
408                                 std::move(document)));
409   ICING_ASSERT_OK_AND_ASSIGN(
410       DocumentStore::PutResult put_result,
411       document_store_->Put(tokenized_document.document()));
412   DocumentId document_id = put_result.new_document_id;
413 
414   ASSERT_THAT(embedding_index_->last_added_document_id(),
415               Eq(kInvalidDocumentId));
416   // Handle document.
417   ICING_ASSERT_OK_AND_ASSIGN(
418       std::unique_ptr<EmbeddingIndexingHandler> handler,
419       EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get(),
420                                        /*enable_embedding_index=*/true));
421   EXPECT_THAT(handler->Handle(
422                   tokenized_document, document_id, put_result.old_document_id,
423                   /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
424               IsOk());
425 
426   // Check index
427   EmbeddingHit hit1(BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0),
428                     /*location=*/0);
429   EmbeddingHit hit2(BasicHit(kSectionIdNestedBodyEmbedding, /*document_id=*/0),
430                     /*location=*/3);
431   EmbeddingHit hit3(BasicHit(kSectionIdNestedTitleEmbedding, /*document_id=*/0),
432                     /*location=*/6);
433   EmbeddingHit hit4(BasicHit(kSectionIdFullDocEmbedding, /*document_id=*/0),
434                     /*location=*/9);
435   // Quantized embeddings are stored in a different location from unquantized
436   // embeddings, so the location starts from 0 again.
437   EmbeddingHit quantized_hit1(
438       BasicHit(kSectionIdNestedQuantizedEmbedding, /*document_id=*/0),
439       /*location=*/0);
440   EmbeddingHit quantized_hit2(
441       BasicHit(kSectionIdNestedQuantizedEmbedding, /*document_id=*/0),
442       /*location=*/3 + sizeof(Quantizer));
443   EXPECT_THAT(GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
444                                         /*model_signature=*/"model"),
445               IsOkAndHolds(ElementsAre(hit1, hit2, quantized_hit1,
446                                        quantized_hit2, hit3, hit4)));
447   // Check unquantized embedding data
448   EXPECT_THAT(
449       GetRawEmbeddingDataFromIndex(embedding_index_.get()),
450       ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3, 2.1, 2.2, 2.3));
451   // Check quantized embedding data
452   EXPECT_THAT(embedding_index_->GetTotalQuantizedVectorSize(),
453               Eq(6 + 2 * sizeof(Quantizer)));
454   EXPECT_THAT(
455       GetAndRestoreQuantizedEmbeddingVectorFromIndex(embedding_index_.get(),
456                                                      quantized_hit1,
457                                                      /*dimension=*/3),
458       IsOkAndHolds(Pointwise(FloatNear(kEpsQuantized), {0.1, 0.2, 0.3})));
459   EXPECT_THAT(
460       GetAndRestoreQuantizedEmbeddingVectorFromIndex(embedding_index_.get(),
461                                                      quantized_hit2,
462                                                      /*dimension=*/3),
463       IsOkAndHolds(Pointwise(FloatNear(kEpsQuantized), {0.4, 0.5, 0.6})));
464 
465   EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
466 }
467 
TEST_F(EmbeddingIndexingHandlerTest,HandleInvalidNewDocumentIdShouldReturnInvalidArgumentError)468 TEST_F(EmbeddingIndexingHandlerTest,
469        HandleInvalidNewDocumentIdShouldReturnInvalidArgumentError) {
470   DocumentProto document =
471       DocumentBuilder()
472           .SetKey("icing", "fake_type/1")
473           .SetSchema(std::string(kFakeType))
474           .AddStringProperty(std::string(kPropertyTitle), "title")
475           .AddVectorProperty(std::string(kPropertyTitleEmbedding),
476                              CreateVector("model", {0.1, 0.2, 0.3}))
477           .AddStringProperty(std::string(kPropertyBody), "body")
478           .AddVectorProperty(std::string(kPropertyBodyEmbedding),
479                              CreateVector("model", {0.4, 0.5, 0.6}),
480                              CreateVector("model", {0.7, 0.8, 0.9}))
481           .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
482                              CreateVector("model", {1.1, 1.2, 1.3}))
483           .Build();
484   ICING_ASSERT_OK_AND_ASSIGN(
485       TokenizedDocument tokenized_document,
486       TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
487                                 std::move(document)));
488   ICING_ASSERT_OK(document_store_->Put(tokenized_document.document()));
489 
490   static constexpr DocumentId kCurrentDocumentId = 3;
491   embedding_index_->set_last_added_document_id(kCurrentDocumentId);
492   ASSERT_THAT(embedding_index_->last_added_document_id(),
493               Eq(kCurrentDocumentId));
494 
495   ICING_ASSERT_OK_AND_ASSIGN(
496       std::unique_ptr<EmbeddingIndexingHandler> handler,
497       EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get(),
498                                        /*enable_embedding_index=*/true));
499 
500   // Handling document with kInvalidDocumentId should cause a failure, and both
501   // index data and last_added_document_id should remain unchanged.
502   EXPECT_THAT(
503       handler->Handle(tokenized_document, kInvalidDocumentId,
504                       /*old_document_id=*/kInvalidDocumentId,
505                       /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
506       StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
507   EXPECT_THAT(embedding_index_->last_added_document_id(),
508               Eq(kCurrentDocumentId));
509   // Check that the embedding index should be empty
510   EXPECT_THAT(GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
511                                         /*model_signature=*/"model"),
512               IsOkAndHolds(IsEmpty()));
513   EXPECT_TRUE(embedding_index_->is_empty());
514   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()), IsEmpty());
515 
516   // Recovery mode should get the same result.
517   EXPECT_THAT(
518       handler->Handle(tokenized_document, kInvalidDocumentId,
519                       /*old_document_id=*/kInvalidDocumentId,
520                       /*recovery_mode=*/true, /*put_document_stats=*/nullptr),
521       StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
522   EXPECT_THAT(embedding_index_->last_added_document_id(),
523               Eq(kCurrentDocumentId));
524   // Check that the embedding index should be empty
525   EXPECT_THAT(GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
526                                         /*model_signature=*/"model"),
527               IsOkAndHolds(IsEmpty()));
528   EXPECT_TRUE(embedding_index_->is_empty());
529   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()), IsEmpty());
530 }
531 
TEST_F(EmbeddingIndexingHandlerTest,HandleOutOfOrderDocumentIdShouldReturnInvalidArgumentError)532 TEST_F(EmbeddingIndexingHandlerTest,
533        HandleOutOfOrderDocumentIdShouldReturnInvalidArgumentError) {
534   DocumentProto document =
535       DocumentBuilder()
536           .SetKey("icing", "fake_type/1")
537           .SetSchema(std::string(kFakeType))
538           .AddStringProperty(std::string(kPropertyTitle), "title")
539           .AddVectorProperty(std::string(kPropertyTitleEmbedding),
540                              CreateVector("model", {0.1, 0.2, 0.3}))
541           .AddStringProperty(std::string(kPropertyBody), "body")
542           .AddVectorProperty(std::string(kPropertyBodyEmbedding),
543                              CreateVector("model", {0.4, 0.5, 0.6}),
544                              CreateVector("model", {0.7, 0.8, 0.9}))
545           .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
546                              CreateVector("model", {1.1, 1.2, 1.3}))
547           .Build();
548   ICING_ASSERT_OK_AND_ASSIGN(
549       TokenizedDocument tokenized_document,
550       TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
551                                 std::move(document)));
552   ICING_ASSERT_OK_AND_ASSIGN(
553       DocumentStore::PutResult put_result,
554       document_store_->Put(tokenized_document.document()));
555   DocumentId document_id = put_result.new_document_id;
556 
557   ICING_ASSERT_OK_AND_ASSIGN(
558       std::unique_ptr<EmbeddingIndexingHandler> handler,
559       EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get(),
560                                        /*enable_embedding_index=*/true));
561 
562   // Handling document with document_id == last_added_document_id should cause a
563   // failure, and both index data and last_added_document_id should remain
564   // unchanged.
565   embedding_index_->set_last_added_document_id(document_id);
566   ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
567   EXPECT_THAT(handler->Handle(
568                   tokenized_document, document_id, put_result.old_document_id,
569                   /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
570               StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
571   EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id));
572 
573   // Check that the embedding index should be empty
574   EXPECT_THAT(GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
575                                         /*model_signature=*/"model"),
576               IsOkAndHolds(IsEmpty()));
577   EXPECT_TRUE(embedding_index_->is_empty());
578   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()), IsEmpty());
579 
580   // Handling document with document_id < last_added_document_id should cause a
581   // failure, and both index data and last_added_document_id should remain
582   // unchanged.
583   embedding_index_->set_last_added_document_id(document_id + 1);
584   ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1));
585   EXPECT_THAT(handler->Handle(
586                   tokenized_document, document_id, put_result.old_document_id,
587                   /*recovery_mode=*/false, /*put_document_stats=*/nullptr),
588               StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
589   EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id + 1));
590 
591   // Check that the embedding index should be empty
592   EXPECT_THAT(GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
593                                         /*model_signature=*/"model"),
594               IsOkAndHolds(IsEmpty()));
595   EXPECT_TRUE(embedding_index_->is_empty());
596   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()), IsEmpty());
597 }
598 
TEST_F(EmbeddingIndexingHandlerTest,HandleRecoveryModeShouldIgnoreDocsLELastAddedDocId)599 TEST_F(EmbeddingIndexingHandlerTest,
600        HandleRecoveryModeShouldIgnoreDocsLELastAddedDocId) {
601   DocumentProto document1 =
602       DocumentBuilder()
603           .SetKey("icing", "fake_type/1")
604           .SetSchema(std::string(kFakeType))
605           .AddStringProperty(std::string(kPropertyTitle), "title one")
606           .AddVectorProperty(std::string(kPropertyTitleEmbedding),
607                              CreateVector("model", {0.1, 0.2, 0.3}))
608           .AddStringProperty(std::string(kPropertyBody), "body one")
609           .AddVectorProperty(std::string(kPropertyBodyEmbedding),
610                              CreateVector("model", {0.4, 0.5, 0.6}),
611                              CreateVector("model", {0.7, 0.8, 0.9}))
612           .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
613                              CreateVector("model", {1.1, 1.2, 1.3}))
614           .Build();
615   DocumentProto document2 =
616       DocumentBuilder()
617           .SetKey("icing", "fake_type/2")
618           .SetSchema(std::string(kFakeType))
619           .AddStringProperty(std::string(kPropertyTitle), "title two")
620           .AddVectorProperty(std::string(kPropertyTitleEmbedding),
621                              CreateVector("model", {10.1, 10.2, 10.3}))
622           .AddStringProperty(std::string(kPropertyBody), "body two")
623           .AddVectorProperty(std::string(kPropertyBodyEmbedding),
624                              CreateVector("model", {10.4, 10.5, 10.6}),
625                              CreateVector("model", {10.7, 10.8, 10.9}))
626           .AddVectorProperty(std::string(kPropertyNonIndexableEmbedding),
627                              CreateVector("model", {11.1, 11.2, 11.3}))
628           .Build();
629   ICING_ASSERT_OK_AND_ASSIGN(
630       TokenizedDocument tokenized_document1,
631       TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
632                                 std::move(document1)));
633   ICING_ASSERT_OK_AND_ASSIGN(
634       TokenizedDocument tokenized_document2,
635       TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
636                                 std::move(document2)));
637   ICING_ASSERT_OK_AND_ASSIGN(
638       DocumentStore::PutResult put_result1,
639       document_store_->Put(tokenized_document1.document()));
640   DocumentId document_id1 = put_result1.new_document_id;
641   ICING_ASSERT_OK_AND_ASSIGN(
642       DocumentStore::PutResult put_result2,
643       document_store_->Put(tokenized_document2.document()));
644   DocumentId document_id2 = put_result2.new_document_id;
645 
646   ICING_ASSERT_OK_AND_ASSIGN(
647       std::unique_ptr<EmbeddingIndexingHandler> handler,
648       EmbeddingIndexingHandler::Create(&fake_clock_, embedding_index_.get(),
649                                        /*enable_embedding_index=*/true));
650 
651   // Handle document with document_id > last_added_document_id in recovery mode.
652   // The handler should index this document and update last_added_document_id.
653   EXPECT_THAT(
654       handler->Handle(tokenized_document1, document_id1,
655                       put_result1.old_document_id, /*recovery_mode=*/true,
656                       /*put_document_stats=*/nullptr),
657       IsOk());
658   EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id1));
659 
660   // Check index
661   EXPECT_THAT(
662       GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
663                                 /*model_signature=*/"model"),
664       IsOkAndHolds(ElementsAre(
665           EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
666                        /*location=*/0),
667           EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
668                        /*location=*/3),
669           EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
670                        /*location=*/6))));
671   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()),
672               ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
673 
674   // Handle document with document_id == last_added_document_id in recovery
675   // mode. We should not get any error, but the handler should ignore the
676   // document, so both index data and last_added_document_id should remain
677   // unchanged.
678   embedding_index_->set_last_added_document_id(document_id2);
679   ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2));
680   EXPECT_THAT(
681       handler->Handle(tokenized_document2, document_id2,
682                       put_result2.old_document_id, /*recovery_mode=*/true,
683                       /*put_document_stats=*/nullptr),
684       IsOk());
685   EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2));
686 
687   // Check index
688   EXPECT_THAT(
689       GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
690                                 /*model_signature=*/"model"),
691       IsOkAndHolds(ElementsAre(
692           EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
693                        /*location=*/0),
694           EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
695                        /*location=*/3),
696           EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
697                        /*location=*/6))));
698   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()),
699               ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
700 
701   // Handle document with document_id < last_added_document_id in recovery mode.
702   // We should not get any error, but the handler should ignore the document, so
703   // both index data and last_added_document_id should remain unchanged.
704   embedding_index_->set_last_added_document_id(document_id2 + 1);
705   ASSERT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1));
706   EXPECT_THAT(
707       handler->Handle(tokenized_document2, document_id2,
708                       put_result2.old_document_id, /*recovery_mode=*/true,
709                       /*put_document_stats=*/nullptr),
710       IsOk());
711   EXPECT_THAT(embedding_index_->last_added_document_id(), Eq(document_id2 + 1));
712 
713   // Check index
714   EXPECT_THAT(
715       GetEmbeddingHitsFromIndex(embedding_index_.get(), /*dimension=*/3,
716                                 /*model_signature=*/"model"),
717       IsOkAndHolds(ElementsAre(
718           EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
719                        /*location=*/0),
720           EmbeddingHit(BasicHit(kSectionIdBodyEmbedding, /*document_id=*/0),
721                        /*location=*/3),
722           EmbeddingHit(BasicHit(kSectionIdTitleEmbedding, /*document_id=*/0),
723                        /*location=*/6))));
724   EXPECT_THAT(GetRawEmbeddingDataFromIndex(embedding_index_.get()),
725               ElementsAre(0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3));
726 }
727 
728 }  // namespace lib
729 }  // namespace icing
730