xref: /aosp_15_r20/external/icing/icing/index/embed/posting-list-embedding-hit-serializer.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/index/embed/posting-list-embedding-hit-serializer.h"
16 
17 #include <cinttypes>
18 #include <cstdint>
19 #include <cstring>
20 #include <limits>
21 #include <vector>
22 
23 #include "icing/text_classifier/lib3/utils/base/status.h"
24 #include "icing/text_classifier/lib3/utils/base/statusor.h"
25 #include "icing/absl_ports/canonical_errors.h"
26 #include "icing/file/posting_list/posting-list-used.h"
27 #include "icing/index/embed/embedding-hit.h"
28 #include "icing/legacy/core/icing-string-util.h"
29 #include "icing/legacy/index/icing-bit-util.h"
30 #include "icing/util/logging.h"
31 #include "icing/util/status-macros.h"
32 
33 namespace icing {
34 namespace lib {
35 
GetBytesUsed(const PostingListUsed * posting_list_used) const36 uint32_t PostingListEmbeddingHitSerializer::GetBytesUsed(
37     const PostingListUsed* posting_list_used) const {
38   // The special hits will be included if they represent actual hits. If they
39   // represent the hit offset or the invalid hit sentinel, they are not
40   // included.
41   return posting_list_used->size_in_bytes() -
42          GetStartByteOffset(posting_list_used);
43 }
44 
GetMinPostingListSizeToFit(const PostingListUsed * posting_list_used) const45 uint32_t PostingListEmbeddingHitSerializer::GetMinPostingListSizeToFit(
46     const PostingListUsed* posting_list_used) const {
47   if (IsFull(posting_list_used) || IsAlmostFull(posting_list_used)) {
48     // If in either the FULL state or ALMOST_FULL state, this posting list *is*
49     // the minimum size posting list that can fit these hits. So just return the
50     // size of the posting list.
51     return posting_list_used->size_in_bytes();
52   }
53 
54   // - In NOT_FULL status, BytesUsed contains no special hits. For a posting
55   //   list in the NOT_FULL state with n hits, we would have n-1 compressed hits
56   //   and 1 uncompressed hit.
57   // - The minimum sized posting list that would be guaranteed to fit these hits
58   //   would be FULL, but calculating the size required for the FULL posting
59   //   list would require deserializing the last two added hits, so instead we
60   //   will calculate the size of an ALMOST_FULL posting list to fit.
61   // - An ALMOST_FULL posting list would have kInvalidHit in special_hit(0), the
62   //   full uncompressed Hit in special_hit(1), and the n-1 compressed hits in
63   //   the compressed region.
64   // - Currently BytesUsed contains one uncompressed Hit and n-1 compressed
65   //   hits.
66   // - Therefore, fitting these hits into a posting list would require
67   //   BytesUsed + one extra full hit.
68   return GetBytesUsed(posting_list_used) + sizeof(EmbeddingHit);
69 }
70 
Clear(PostingListUsed * posting_list_used) const71 void PostingListEmbeddingHitSerializer::Clear(
72     PostingListUsed* posting_list_used) const {
73   // Safe to ignore return value because posting_list_used->size_in_bytes() is
74   // a valid argument.
75   SetStartByteOffset(posting_list_used,
76                      /*offset=*/posting_list_used->size_in_bytes());
77 }
78 
MoveFrom(PostingListUsed * dst,PostingListUsed * src) const79 libtextclassifier3::Status PostingListEmbeddingHitSerializer::MoveFrom(
80     PostingListUsed* dst, PostingListUsed* src) const {
81   ICING_RETURN_ERROR_IF_NULL(dst);
82   ICING_RETURN_ERROR_IF_NULL(src);
83   if (GetMinPostingListSizeToFit(src) > dst->size_in_bytes()) {
84     return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
85         "src MinPostingListSizeToFit %d must be larger than size %d.",
86         GetMinPostingListSizeToFit(src), dst->size_in_bytes()));
87   }
88 
89   if (!IsPostingListValid(dst)) {
90     return absl_ports::FailedPreconditionError(
91         "Dst posting list is in an invalid state and can't be used!");
92   }
93   if (!IsPostingListValid(src)) {
94     return absl_ports::InvalidArgumentError(
95         "Cannot MoveFrom an invalid src posting list!");
96   }
97 
98   // Pop just enough hits that all of src's compressed hits fit in
99   // dst posting_list's compressed area. Then we can memcpy that area.
100   std::vector<EmbeddingHit> hits;
101   while (IsFull(src) || IsAlmostFull(src) ||
102          (dst->size_in_bytes() - kSpecialHitsSize < GetBytesUsed(src))) {
103     if (!GetHitsInternal(src, /*limit=*/1, /*pop=*/true, &hits).ok()) {
104       return absl_ports::AbortedError(
105           "Unable to retrieve hits from src posting list.");
106     }
107   }
108 
109   // memcpy the area and set up start byte offset.
110   Clear(dst);
111   memcpy(dst->posting_list_buffer() + dst->size_in_bytes() - GetBytesUsed(src),
112          src->posting_list_buffer() + GetStartByteOffset(src),
113          GetBytesUsed(src));
114   // Because we popped all hits from src outside of the compressed area and we
115   // guaranteed that GetBytesUsed(src) is less than dst->size_in_bytes() -
116   // kSpecialHitSize. This is guaranteed to be a valid byte offset for the
117   // NOT_FULL state, so ignoring the value is safe.
118   SetStartByteOffset(dst, dst->size_in_bytes() - GetBytesUsed(src));
119 
120   // Put back remaining hits.
121   for (size_t i = 0; i < hits.size(); i++) {
122     const EmbeddingHit& hit = hits[hits.size() - i - 1];
123     // PrependHit can return either INVALID_ARGUMENT - if hit is invalid or not
124     // less than the previous hit - or RESOURCE_EXHAUSTED. RESOURCE_EXHAUSTED
125     // should be impossible because we've already assured that there is enough
126     // room above.
127     ICING_RETURN_IF_ERROR(PrependHit(dst, hit));
128   }
129 
130   Clear(src);
131   return libtextclassifier3::Status::OK;
132 }
133 
GetPadEnd(const PostingListUsed * posting_list_used,uint32_t offset) const134 uint32_t PostingListEmbeddingHitSerializer::GetPadEnd(
135     const PostingListUsed* posting_list_used, uint32_t offset) const {
136   EmbeddingHit::Value pad;
137   uint32_t pad_end = offset;
138   while (pad_end < posting_list_used->size_in_bytes()) {
139     size_t pad_len = VarInt::Decode(
140         posting_list_used->posting_list_buffer() + pad_end, &pad);
141     if (pad != 0) {
142       // No longer a pad.
143       break;
144     }
145     pad_end += pad_len;
146   }
147   return pad_end;
148 }
149 
PadToEnd(PostingListUsed * posting_list_used,uint32_t start,uint32_t end) const150 bool PostingListEmbeddingHitSerializer::PadToEnd(
151     PostingListUsed* posting_list_used, uint32_t start, uint32_t end) const {
152   if (end > posting_list_used->size_in_bytes()) {
153     ICING_LOG(ERROR) << "Cannot pad a region that ends after size!";
154     return false;
155   }
156   // In VarInt a value of 0 encodes to 0.
157   memset(posting_list_used->posting_list_buffer() + start, 0, end - start);
158   return true;
159 }
160 
// Prepends `hit` to a posting list in the ALMOST_FULL state. The current
// first hit lives uncompressed in special position 1; we try to re-encode it
// as a delta in the pad area and move `hit` into special position 1. If the
// delta does not fit, `hit` goes into special position 0 and the posting list
// becomes FULL.
//
// Returns:
//   - OK on success
//   - INVALID_ARGUMENT if `hit` is not strictly less than the current first
//     hit
libtextclassifier3::Status
PostingListEmbeddingHitSerializer::PrependHitToAlmostFull(
    PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
  // Get delta between first hit and the new hit. Try to fit delta
  // in the padded area and put new hit at the special position 1.
  EmbeddingHit cur = GetSpecialHit(posting_list_used, /*index=*/1);
  if (cur.value() <= hit.value()) {
    // Hits must be prepended in strictly decreasing order.
    return absl_ports::InvalidArgumentError(
        "Hit being prepended must be strictly less than the most recent Hit");
  }
  uint64_t delta = cur.value() - hit.value();
  uint8_t delta_buf[VarInt::kMaxEncodedLen64];
  size_t delta_len = VarInt::Encode(delta, delta_buf);

  // Find where the padding (if any) after the special hits ends.
  uint32_t pad_end = GetPadEnd(posting_list_used,
                               /*offset=*/kSpecialHitsSize);

  if (pad_end >= kSpecialHitsSize + delta_len) {
    // Pad area has enough space for delta of existing hit (cur). Write delta
    // at pad_end - delta_len so it directly precedes the compressed hits.
    uint8_t* delta_offset =
        posting_list_used->posting_list_buffer() + pad_end - delta_len;
    memcpy(delta_offset, delta_buf, delta_len);

    // Now the new hit becomes the first hit, stored at special position 1.
    SetSpecialHit(posting_list_used, /*index=*/1, hit);
    // sizeof(EmbeddingHit) is always a valid offset (ALMOST_FULL state), so
    // the return value can be ignored.
    SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
  } else {
    // No space for delta. We put the new hit at special position 0 and the
    // posting list transitions to the FULL state.
    SetSpecialHit(posting_list_used, /*index=*/0, hit);
  }
  return libtextclassifier3::Status::OK;
}
200 
PrependHitToEmpty(PostingListUsed * posting_list_used,const EmbeddingHit & hit) const201 void PostingListEmbeddingHitSerializer::PrependHitToEmpty(
202     PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
203   // First hit to be added. Just add verbatim, no compression.
204   if (posting_list_used->size_in_bytes() == kSpecialHitsSize) {
205     // Safe to ignore the return value because 1 < kNumSpecialData
206     SetSpecialHit(posting_list_used, /*index=*/1, hit);
207     // Safe to ignore the return value because sizeof(EmbeddingHit) is a valid
208     // argument.
209     SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
210   } else {
211     // Since this is the first hit, size != kSpecialHitsSize and
212     // size % sizeof(EmbeddingHit) == 0, we know that there is room to fit 'hit'
213     // into the compressed region, so ValueOrDie is safe.
214     uint32_t offset =
215         PrependHitUncompressed(posting_list_used, hit,
216                                /*offset=*/posting_list_used->size_in_bytes())
217             .ValueOrDie();
218     // Safe to ignore the return value because PrependHitUncompressed is
219     // guaranteed to return a valid offset.
220     SetStartByteOffset(posting_list_used, offset);
221   }
222 }
223 
224 libtextclassifier3::Status
PrependHitToNotFull(PostingListUsed * posting_list_used,const EmbeddingHit & hit,uint32_t offset) const225 PostingListEmbeddingHitSerializer::PrependHitToNotFull(
226     PostingListUsed* posting_list_used, const EmbeddingHit& hit,
227     uint32_t offset) const {
228   // First hit in compressed area. It is uncompressed. See if delta
229   // between the first hit and new hit will still fit in the
230   // compressed area.
231   if (offset + sizeof(EmbeddingHit::Value) >
232       posting_list_used->size_in_bytes()) {
233     // The first hit in the compressed region *should* be uncompressed, but
234     // somehow there isn't enough room between offset and the end of the
235     // compressed area to fit an uncompressed hit. This should NEVER happen.
236     return absl_ports::FailedPreconditionError(
237         "Posting list is in an invalid state.");
238   }
239   EmbeddingHit::Value cur_value;
240   memcpy(&cur_value, posting_list_used->posting_list_buffer() + offset,
241          sizeof(EmbeddingHit::Value));
242   if (cur_value <= hit.value()) {
243     return absl_ports::InvalidArgumentError(
244         IcingStringUtil::StringPrintf("EmbeddingHit %" PRId64
245                                       " being prepended must be "
246                                       "strictly less than the most recent "
247                                       "EmbeddingHit %" PRId64,
248                                       hit.value(), cur_value));
249   }
250   uint64_t delta = cur_value - hit.value();
251   uint8_t delta_buf[VarInt::kMaxEncodedLen64];
252   size_t delta_len = VarInt::Encode(delta, delta_buf);
253 
254   // offset now points to one past the end of the first hit.
255   offset += sizeof(EmbeddingHit::Value);
256   if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) + delta_len <= offset) {
257     // Enough space for delta in compressed area.
258 
259     // Prepend delta.
260     offset -= delta_len;
261     memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf,
262            delta_len);
263 
264     // Prepend new hit. We know that there is room for 'hit' because of the if
265     // statement above, so calling ValueOrDie is safe.
266     offset =
267         PrependHitUncompressed(posting_list_used, hit, offset).ValueOrDie();
268     // offset is guaranteed to be valid here. So it's safe to ignore the return
269     // value. The if above will guarantee that offset >= kSpecialHitSize and <
270     // posting_list_used->size_in_bytes() because the if ensures that there is
271     // enough room between offset and kSpecialHitSize to fit the delta of the
272     // previous hit and the uncompressed hit.
273     SetStartByteOffset(posting_list_used, offset);
274   } else if (kSpecialHitsSize + delta_len <= offset) {
275     // Only have space for delta. The new hit must be put in special
276     // position 1.
277 
278     // Prepend delta.
279     offset -= delta_len;
280     memcpy(posting_list_used->posting_list_buffer() + offset, delta_buf,
281            delta_len);
282 
283     // Prepend pad. Safe to ignore the return value of PadToEnd because offset
284     // must be less than posting_list_used->size_in_bytes(). Otherwise, this
285     // function already would have returned FAILED_PRECONDITION.
286     PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize,
287              /*end=*/offset);
288 
289     // Put new hit in special position 1. Safe to ignore return value because 1
290     // < kNumSpecialData.
291     SetSpecialHit(posting_list_used, /*index=*/1, hit);
292 
293     // State almost_full. Safe to ignore the return value because
294     // sizeof(EmbeddingHit) is a valid argument.
295     SetStartByteOffset(posting_list_used, /*offset=*/sizeof(EmbeddingHit));
296   } else {
297     // Very rare case where delta is larger than sizeof(EmbeddingHit::Value)
298     // (i.e. varint delta encoding expanded required storage). We
299     // move first hit to special position 1 and put new hit in
300     // special position 0.
301     EmbeddingHit cur(cur_value);
302     // Safe to ignore the return value of PadToEnd because offset must be less
303     // than posting_list_used->size_in_bytes(). Otherwise, this function
304     // already would have returned FAILED_PRECONDITION.
305     PadToEnd(posting_list_used, /*start=*/kSpecialHitsSize,
306              /*end=*/offset);
307     // Safe to ignore the return value here because 0 and 1 < kNumSpecialData.
308     SetSpecialHit(posting_list_used, /*index=*/1, cur);
309     SetSpecialHit(posting_list_used, /*index=*/0, hit);
310   }
311   return libtextclassifier3::Status::OK;
312 }
313 
PrependHit(PostingListUsed * posting_list_used,const EmbeddingHit & hit) const314 libtextclassifier3::Status PostingListEmbeddingHitSerializer::PrependHit(
315     PostingListUsed* posting_list_used, const EmbeddingHit& hit) const {
316   static_assert(
317       sizeof(EmbeddingHit::Value) <= sizeof(uint64_t),
318       "EmbeddingHit::Value cannot be larger than 8 bytes because the delta "
319       "must be able to fit in 8 bytes.");
320   if (!hit.is_valid()) {
321     return absl_ports::InvalidArgumentError("Cannot prepend an invalid hit!");
322   }
323   if (!IsPostingListValid(posting_list_used)) {
324     return absl_ports::FailedPreconditionError(
325         "This PostingListUsed is in an invalid state and can't add any hits!");
326   }
327 
328   if (IsFull(posting_list_used)) {
329     // State full: no space left.
330     return absl_ports::ResourceExhaustedError("No more room for hits");
331   } else if (IsAlmostFull(posting_list_used)) {
332     return PrependHitToAlmostFull(posting_list_used, hit);
333   } else if (IsEmpty(posting_list_used)) {
334     PrependHitToEmpty(posting_list_used, hit);
335     return libtextclassifier3::Status::OK;
336   } else {
337     uint32_t offset = GetStartByteOffset(posting_list_used);
338     return PrependHitToNotFull(posting_list_used, hit, offset);
339   }
340 }
341 
342 libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
GetHits(const PostingListUsed * posting_list_used) const343 PostingListEmbeddingHitSerializer::GetHits(
344     const PostingListUsed* posting_list_used) const {
345   std::vector<EmbeddingHit> hits_out;
346   ICING_RETURN_IF_ERROR(GetHits(posting_list_used, &hits_out));
347   return hits_out;
348 }
349 
GetHits(const PostingListUsed * posting_list_used,std::vector<EmbeddingHit> * hits_out) const350 libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHits(
351     const PostingListUsed* posting_list_used,
352     std::vector<EmbeddingHit>* hits_out) const {
353   return GetHitsInternal(posting_list_used,
354                          /*limit=*/std::numeric_limits<uint32_t>::max(),
355                          /*pop=*/false, hits_out);
356 }
357 
// Removes the first `num_hits` hits from the posting list. The FULL,
// num_hits == 1 case is handled specially so that popping exactly reverses
// the prepend that produced the FULL state (see the diagrams below).
//
// Returns:
//   - OK on success
//   - Any error propagated from GetHitsInternal / PrependHit
libtextclassifier3::Status PostingListEmbeddingHitSerializer::PopFrontHits(
    PostingListUsed* posting_list_used, uint32_t num_hits) const {
  if (num_hits == 1 && IsFull(posting_list_used)) {
    // The PL is in full status which means that we save 2 uncompressed hits in
    // the 2 special positions. But full status may be reached from 2 different
    // statuses.
    // (1) In "almost full" status
    // +-----------------+----------------+-------+-----------------+
    // |Hit::kInvalidVal |1st hit         |(pad)  |(compressed) hits|
    // +-----------------+----------------+-------+-----------------+
    // When we prepend another hit, we can only put it at the special
    // position 0. And we get a full PL
    // +-----------------+----------------+-------+-----------------+
    // |new 1st hit      |original 1st hit|(pad)  |(compressed) hits|
    // +-----------------+----------------+-------+-----------------+
    // (2) In "not full" status
    // +-----------------+----------------+------+-------+------------------+
    // |hits-start-offset|Hit::kInvalidVal|(pad) |1st hit|(compressed) hits |
    // +-----------------+----------------+------+-------+------------------+
    // When we prepend another hit, we can reach any of the 3 following
    // scenarios:
    // (2.1) not full
    // if the space of pad and original 1st hit can accommodate the new 1st hit
    // and the encoded delta value.
    // +-----------------+----------------+------+-----------+-----------------+
    // |hits-start-offset|Hit::kInvalidVal|(pad) |new 1st hit|(compressed) hits|
    // +-----------------+----------------+------+-----------+-----------------+
    // (2.2) almost full
    // If the space of pad and original 1st hit cannot accommodate the new 1st
    // hit and the encoded delta value but can accommodate the encoded delta
    // value only. We can put the new 1st hit at special position 1.
    // +-----------------+----------------+-------+-----------------+
    // |Hit::kInvalidVal |new 1st hit     |(pad)  |(compressed) hits|
    // +-----------------+----------------+-------+-----------------+
    // (2.3) full
    // In very rare case, it cannot even accommodate only the encoded delta
    // value. we can move the original 1st hit into special position 1 and the
    // new 1st hit into special position 0. This may happen because we use
    // VarInt encoding method which may make the encoded value longer (about
    // 4/3 times of original)
    // +-----------------+----------------+-------+-----------------+
    // |new 1st hit      |original 1st hit|(pad)  |(compressed) hits|
    // +-----------------+----------------+-------+-----------------+
    // Suppose now the PL is full. But we don't know whether it arrived at
    // this status from "not full" like (2.3) or from "almost full" like (1).
    // We'll return to "almost full" status like (1) if we simply pop the new
    // 1st hit but we want to make the prepending operation "reversible". So
    // there should be some way to return to "not full" if possible. A simple
    // way to do it is to pop 2 hits out of the PL to status "almost full" or
    // "not full".  And add the original 1st hit back. We can return to the
    // correct original statuses of (2.1) or (1). This makes our prepending
    // operation reversible.
    std::vector<EmbeddingHit> out;

    // Popping 2 hits should never fail because we've just ensured that the
    // posting list is in the FULL state.
    ICING_RETURN_IF_ERROR(
        GetHitsInternal(posting_list_used, /*limit=*/2, /*pop=*/true, &out));

    // PrependHit should never fail because out[1] is a valid hit less than
    // previous hits in the posting list and because there's no way that the
    // posting list could run out of room because it previously stored this hit
    // AND another hit.
    ICING_RETURN_IF_ERROR(PrependHit(posting_list_used, out[1]));
  } else if (num_hits > 0) {
    // General case: traverse-and-discard num_hits hits.
    return GetHitsInternal(posting_list_used, /*limit=*/num_hits, /*pop=*/true,
                           nullptr);
  }
  return libtextclassifier3::Status::OK;
}
428 
// Traverses up to `limit` hits from the front of the posting list, appending
// them to `out` when `out` is non-null. When `pop` is true, the traversed
// hits are additionally removed from the posting list (the const_cast below
// reflects that pop mutates it).
//
// Returns:
//   - OK on success
libtextclassifier3::Status PostingListEmbeddingHitSerializer::GetHitsInternal(
    const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
    std::vector<EmbeddingHit>* out) const {
  // Put current uncompressed val here.
  EmbeddingHit::Value val = EmbeddingHit::kInvalidValue;
  uint32_t offset = GetStartByteOffset(posting_list_used);
  uint32_t count = 0;

  // First traverse the first two special positions (only present in the FULL
  // and ALMOST_FULL states, where offset < kSpecialHitsSize).
  while (count < limit && offset < kSpecialHitsSize) {
    // offset / sizeof(EmbeddingHit) < kNumSpecialData because of the loop
    // condition above, so the index is in range.
    EmbeddingHit hit = GetSpecialHit(posting_list_used,
                                     /*index=*/offset / sizeof(EmbeddingHit));
    val = hit.value();
    if (out != nullptr) {
      out->push_back(hit);
    }
    offset += sizeof(EmbeddingHit);
    count++;
  }

  // If special position 1 was set then we need to skip padding before the
  // compressed region.
  if (val != EmbeddingHit::kInvalidValue && offset == kSpecialHitsSize) {
    offset = GetPadEnd(posting_list_used, offset);
  }

  // Walk the compressed region: first hit is uncompressed, subsequent hits
  // are varint-encoded deltas added onto the running value.
  while (count < limit && offset < posting_list_used->size_in_bytes()) {
    if (val == EmbeddingHit::kInvalidValue) {
      // First hit is in compressed area. Put that in val.
      memcpy(&val, posting_list_used->posting_list_buffer() + offset,
             sizeof(EmbeddingHit::Value));
      offset += sizeof(EmbeddingHit::Value);
    } else {
      // Now we have delta encoded subsequent hits. Decode and push.
      uint64_t delta;
      offset += VarInt::Decode(
          posting_list_used->posting_list_buffer() + offset, &delta);
      val += delta;
    }
    EmbeddingHit hit(val);
    if (out != nullptr) {
      out->push_back(hit);
    }
    count++;
  }

  if (pop) {
    PostingListUsed* mutable_posting_list_used =
        const_cast<PostingListUsed*>(posting_list_used);
    // Modify the posting list so that we pop all hits actually
    // traversed.
    if (offset >= kSpecialHitsSize &&
        offset < posting_list_used->size_in_bytes()) {
      // In the compressed area. Pop and reconstruct. offset/val is
      // the last traversed hit, which we must discard. So move one
      // more forward.
      uint64_t delta;
      offset += VarInt::Decode(
          posting_list_used->posting_list_buffer() + offset, &delta);
      val += delta;

      // Now val is the first hit of the new posting list.
      if (kSpecialHitsSize + sizeof(EmbeddingHit::Value) <= offset) {
        // val fits in compressed area. Simply copy.
        offset -= sizeof(EmbeddingHit::Value);
        memcpy(mutable_posting_list_used->posting_list_buffer() + offset, &val,
               sizeof(EmbeddingHit::Value));
      } else {
        // val won't fit in compressed area: store it in special position 1
        // and fall back to the ALMOST_FULL layout.
        EmbeddingHit hit(val);
        SetSpecialHit(mutable_posting_list_used, /*index=*/1, hit);

        // Prepend pad. Safe to ignore the return value of PadToEnd because
        // offset must be less than posting_list_used->size_in_bytes() thanks
        // to the if above.
        PadToEnd(mutable_posting_list_used,
                 /*start=*/kSpecialHitsSize,
                 /*end=*/offset);
        offset = sizeof(EmbeddingHit);
      }
    }
    // offset is guaranteed to be valid so ignoring the return value of
    // SetStartByteOffset is safe. It falls into one of four scenarios:
    // Scenario 1: the above if was false because offset is not <
    //             posting_list_used->size_in_bytes()
    //   In this case, offset must be == posting_list_used->size_in_bytes()
    //   because we reached offset by unwinding hits on the posting list.
    // Scenario 2: offset is < kSpecialHitsSize
    //   In this case, offset is guaranteed to be either 0 or
    //   sizeof(EmbeddingHit) because offset is incremented by
    //   sizeof(EmbeddingHit) within the first while loop.
    // Scenario 3: offset is within the compressed region and the new first
    //   hit in the posting list (the value that 'val' holds) will fit as an
    //   uncompressed hit in the compressed region. The resulting offset from
    //   decompressing val must be >= kSpecialHitsSize because otherwise we'd
    //   be in Scenario 4.
    // Scenario 4: offset is within the compressed region, but the new first
    //   hit in the posting list is too large to fit as an uncompressed hit in
    //   the compressed region. Therefore, it must be stored in a special hit
    //   and offset will be sizeof(EmbeddingHit).
    SetStartByteOffset(mutable_posting_list_used, offset);
  }

  return libtextclassifier3::Status::OK;
}
536 
GetSpecialHit(const PostingListUsed * posting_list_used,uint32_t index) const537 EmbeddingHit PostingListEmbeddingHitSerializer::GetSpecialHit(
538     const PostingListUsed* posting_list_used, uint32_t index) const {
539   static_assert(sizeof(EmbeddingHit::Value) >= sizeof(uint32_t), "HitTooSmall");
540   EmbeddingHit val(EmbeddingHit::kInvalidValue);
541   memcpy(&val, posting_list_used->posting_list_buffer() + index * sizeof(val),
542          sizeof(val));
543   return val;
544 }
545 
SetSpecialHit(PostingListUsed * posting_list_used,uint32_t index,const EmbeddingHit & val) const546 void PostingListEmbeddingHitSerializer::SetSpecialHit(
547     PostingListUsed* posting_list_used, uint32_t index,
548     const EmbeddingHit& val) const {
549   memcpy(posting_list_used->posting_list_buffer() + index * sizeof(val), &val,
550          sizeof(val));
551 }
552 
IsPostingListValid(const PostingListUsed * posting_list_used) const553 bool PostingListEmbeddingHitSerializer::IsPostingListValid(
554     const PostingListUsed* posting_list_used) const {
555   if (IsAlmostFull(posting_list_used)) {
556     // Special Hit 1 should hold a Hit. Calling ValueOrDie is safe because we
557     // know that 1 < kNumSpecialData.
558     if (!GetSpecialHit(posting_list_used, /*index=*/1).is_valid()) {
559       ICING_LOG(ERROR)
560           << "Both special hits cannot be invalid at the same time.";
561       return false;
562     }
563   } else if (!IsFull(posting_list_used)) {
564     // NOT_FULL. Special Hit 0 should hold a valid offset. Calling ValueOrDie is
565     // safe because we know that 0 < kNumSpecialData.
566     if (GetSpecialHit(posting_list_used, /*index=*/0).value() >
567             posting_list_used->size_in_bytes() ||
568         GetSpecialHit(posting_list_used, /*index=*/0).value() <
569             kSpecialHitsSize) {
570       ICING_LOG(ERROR) << "EmbeddingHit: "
571                        << GetSpecialHit(posting_list_used, /*index=*/0).value()
572                        << " size: " << posting_list_used->size_in_bytes()
573                        << " sp size: " << kSpecialHitsSize;
574       return false;
575     }
576   }
577   return true;
578 }
579 
GetStartByteOffset(const PostingListUsed * posting_list_used) const580 uint32_t PostingListEmbeddingHitSerializer::GetStartByteOffset(
581     const PostingListUsed* posting_list_used) const {
582   if (IsFull(posting_list_used)) {
583     return 0;
584   } else if (IsAlmostFull(posting_list_used)) {
585     return sizeof(EmbeddingHit);
586   } else {
587     // NOT_FULL, calling ValueOrDie is safe because we know that 0 <
588     // kNumSpecialData.
589     return GetSpecialHit(posting_list_used, /*index=*/0).value();
590   }
591 }
592 
// Validates and records `offset` as the posting list's start byte offset.
// Legal values are 0 (FULL), sizeof(EmbeddingHit) (ALMOST_FULL), or any value
// in [kSpecialHitsSize, size_in_bytes()] (NOT_FULL). Returns false without
// modifying anything if `offset` is illegal.
bool PostingListEmbeddingHitSerializer::SetStartByteOffset(
    PostingListUsed* posting_list_used, uint32_t offset) const {
  if (offset > posting_list_used->size_in_bytes()) {
    ICING_LOG(ERROR) << "offset cannot be a value greater than size "
                     << posting_list_used->size_in_bytes() << ". offset is "
                     << offset << ".";
    return false;
  }
  // Reject offsets strictly between the two legal special values.
  if (offset < kSpecialHitsSize && offset > sizeof(EmbeddingHit)) {
    ICING_LOG(ERROR) << "offset cannot be a value between ("
                     << sizeof(EmbeddingHit) << ", " << kSpecialHitsSize
                     << "). offset is " << offset << ".";
    return false;
  }
  if (offset < sizeof(EmbeddingHit) && offset != 0) {
    ICING_LOG(ERROR) << "offset cannot be a value between (0, "
                     << sizeof(EmbeddingHit) << "). offset is " << offset
                     << ".";
    return false;
  }
  if (offset >= kSpecialHitsSize) {
    // NOT_FULL state: store the offset in special position 0 and the invalid
    // sentinel in special position 1.
    SetSpecialHit(posting_list_used, /*index=*/0, EmbeddingHit(offset));
    SetSpecialHit(posting_list_used, /*index=*/1,
                  EmbeddingHit(EmbeddingHit::kInvalidValue));
  } else if (offset == sizeof(EmbeddingHit)) {
    // ALMOST_FULL state: special position 0 holds the invalid sentinel;
    // special position 1 (the first hit) is managed by the caller.
    SetSpecialHit(posting_list_used, /*index=*/0,
                  EmbeddingHit(EmbeddingHit::kInvalidValue));
  }
  // Nothing to do for the FULL state (offset == 0) - the offset isn't
  // actually stored anywhere and both special hits hold valid hits.
  return true;
}
629 
630 libtextclassifier3::StatusOr<uint32_t>
PrependHitUncompressed(PostingListUsed * posting_list_used,const EmbeddingHit & hit,uint32_t offset) const631 PostingListEmbeddingHitSerializer::PrependHitUncompressed(
632     PostingListUsed* posting_list_used, const EmbeddingHit& hit,
633     uint32_t offset) const {
634   if (offset < kSpecialHitsSize + sizeof(EmbeddingHit::Value)) {
635     return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
636         "Not enough room to prepend EmbeddingHit::Value at offset %d.",
637         offset));
638   }
639   offset -= sizeof(EmbeddingHit::Value);
640   EmbeddingHit::Value val = hit.value();
641   memcpy(posting_list_used->posting_list_buffer() + offset, &val,
642          sizeof(EmbeddingHit::Value));
643   return offset;
644 }
645 
646 }  // namespace lib
647 }  // namespace icing
648