1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/index/embed/embedding-hit.h"
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <vector>
20
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "icing/index/hit/hit.h"
24 #include "icing/schema/section.h"
25 #include "icing/store/document-id.h"
26
27 namespace icing {
28 namespace lib {
29
30 namespace {
31
32 using ::testing::ElementsAre;
33 using ::testing::Eq;
34 using ::testing::IsFalse;
35
36 static constexpr DocumentId kSomeDocumentId = 24;
37 static constexpr SectionId kSomeSectionid = 5;
38 static constexpr uint32_t kSomeLocation = 123;
39
// Constructs an EmbeddingHit from a BasicHit and a location and checks that
// both accessors hand back exactly what was passed in.
TEST(EmbeddingHitTest, Accessors) {
  BasicHit expected_basic_hit(kSomeSectionid, kSomeDocumentId);
  EmbeddingHit hit_under_test(expected_basic_hit, kSomeLocation);

  EXPECT_THAT(hit_under_test.basic_hit(), Eq(expected_basic_hit));
  EXPECT_THAT(hit_under_test.location(), Eq(kSomeLocation));
}
46
// An EmbeddingHit built from EmbeddingHit::kInvalidValue must report itself as
// invalid, and its decoded fields must match the values asserted below.
TEST(EmbeddingHitTest, Invalid) {
  EmbeddingHit invalid_hit(EmbeddingHit::kInvalidValue);
  EXPECT_THAT(invalid_hit.is_valid(), IsFalse());

  // Also make sure the invalid EmbeddingHit decodes to an invalid document id,
  // the minimum section id, and a zero location.
  EXPECT_THAT(invalid_hit.basic_hit().document_id(), Eq(kInvalidDocumentId));
  EXPECT_THAT(invalid_hit.basic_hit().section_id(), Eq(kMinSectionId));
  // Compare against an unsigned literal: location() is presumably unsigned
  // (locations in this file are uint32_t), and a plain `0` would make the
  // matcher perform a signed/unsigned comparison (-Wsign-compare).
  EXPECT_THAT(invalid_hit.location(), Eq(0u));
}
56
// Sorts a shuffled set of EmbeddingHits and checks they come out ordered by
// their BasicHit first and by location second.
TEST(EmbeddingHitTest, Comparison) {
  // These BasicHits are constructed so that first_basic < second_basic <
  // third_basic.
  BasicHit first_basic(/*section_id=*/1, /*document_id=*/2409);
  BasicHit second_basic(/*section_id=*/1, /*document_id=*/243);
  BasicHit third_basic(/*section_id=*/15, /*document_id=*/243);

  // Ordering is by BasicHit, then by location, so the expected sorted order is
  //   from_first < from_second < third_at_loc0 < third_at_loc10.
  EmbeddingHit third_at_loc10(third_basic, /*location=*/10);
  EmbeddingHit third_at_loc0(third_basic, /*location=*/0);
  EmbeddingHit from_first(first_basic, /*location=*/100);
  EmbeddingHit from_second(second_basic, /*location=*/0);

  // Insert deliberately out of order before sorting.
  std::vector<EmbeddingHit> hits;
  hits.push_back(third_at_loc10);
  hits.push_back(third_at_loc0);
  hits.push_back(from_first);
  hits.push_back(from_second);

  std::sort(hits.begin(), hits.end());

  EXPECT_THAT(hits, ElementsAre(from_first, from_second, third_at_loc0,
                                third_at_loc10));
}
76
77 } // namespace
78
79 } // namespace lib
80 } // namespace icing
81