// Copyright 2023 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

/* This is a wrapper for the Google range-sse.cc algorithm which checks whether
 * a sequence of bytes is a valid UTF-8 sequence and finds the longest valid
 * prefix of the UTF-8 sequence.
 *
 * The key difference is that it checks for as many ASCII symbols as possible
 * and then falls back to the range-sse.cc algorithm. The changes to the
 * algorithm are cosmetic, mostly to trick the clang compiler into producing
 * optimal code.
 *
 * For API see the utf8_validity.h header.
 */
#include "utf8_range.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#ifdef __SSE4_1__
#include <emmintrin.h>
#include <smmintrin.h>
#include <tmmintrin.h>
#endif

/* FORCE_INLINE_ATTR: compiler-specific "always inline" decoration applied to
 * the small helpers in this file. It expands to nothing on compilers that are
 * neither GNU-compatible nor MSVC.
 */
#if defined(__GNUC__)
#define FORCE_INLINE_ATTR __attribute__((always_inline))
#elif defined(_MSC_VER)
#define FORCE_INLINE_ATTR __forceinline
#else
#define FORCE_INLINE_ATTR
#endif

utf8_range_UnalignedLoad64(const void * p)38 static FORCE_INLINE_ATTR inline uint64_t utf8_range_UnalignedLoad64(
39     const void* p) {
40   uint64_t t;
41   memcpy(&t, p, sizeof t);
42   return t;
43 }
44 
utf8_range_AsciiIsAscii(unsigned char c)45 static FORCE_INLINE_ATTR inline int utf8_range_AsciiIsAscii(unsigned char c) {
46   return c < 128;
47 }
48 
utf8_range_IsTrailByteOk(const char c)49 static FORCE_INLINE_ATTR inline int utf8_range_IsTrailByteOk(const char c) {
50   return (int8_t)(c) <= (int8_t)(0xBF);
51 }
52 
/* Byte-at-a-time UTF-8 validator for the range [data, end).
 *
 * |return_position| must be 0 or 1, because the final result is computed as
 * err_pos + (1 - return_position):
 *  - 0: returns 1 if [data, end) is a valid UTF-8 sequence, otherwise 0.
 *  - 1: returns the length in bytes of the prefix of [data, end) that is all
 *       structurally valid UTF-8.
 * A codepoint that is cut off by |end| counts as invalid.
 */
static size_t utf8_range_ValidateUTF8Naive(const char* data, const char* end,
                                           int return_position) {
  /* We return err_pos in the loop which is always 0 if !return_position */
  size_t err_pos = 0;
  size_t codepoint_bytes = 0;
  /* The advance by the previous codepoint's size happens at the top of the
   * loop because every accepted codepoint size leaves the body through an
   * early `continue`: first ASCII, then 2-byte codepoints, and so on. This
   * reduces indentation and keeps the validity checks readable.
   */
  while (data + codepoint_bytes < end) {
    if (return_position) {
      err_pos += codepoint_bytes;
    }
    data += codepoint_bytes;
    const size_t len = end - data;
    const unsigned char byte1 = data[0];

    /* We do not skip many ascii bytes at the same time as this function is
       used for tail checking (< 16 bytes) and for non x86 platforms. We also
       don't think that cases where non-ASCII codepoints are followed by ascii
       happen often. For small strings it also introduces some penalty. For
       purely ascii UTF8 strings (which is the overwhelming case) we call
       SkipAscii function which is multiplatform and extremely fast.
     */
    /* [00..7F] ASCII -> 1 byte */
    if (utf8_range_AsciiIsAscii(byte1)) {
      codepoint_bytes = 1;
      continue;
    }
    /* [C2..DF], [80..BF] -> 2 bytes */
    if (len >= 2 && byte1 >= 0xC2 && byte1 <= 0xDF &&
        utf8_range_IsTrailByteOk(data[1])) {
      codepoint_bytes = 2;
      continue;
    }
    if (len >= 3) {
      const unsigned char byte2 = data[1];
      const unsigned char byte3 = data[2];

      /* Every 3- and 4-byte form needs byte2 and byte3 to be continuation
       * bytes in [0x80, 0xBF]; the narrower byte2 ranges are checked below.
       */
      if (!utf8_range_IsTrailByteOk(byte2) ||
          !utf8_range_IsTrailByteOk(byte3)) {
        return err_pos;
      }

      if (/* E0, A0..BF, 80..BF */
          ((byte1 == 0xE0 && byte2 >= 0xA0) ||
           /* E1..EC, 80..BF, 80..BF */
           (byte1 >= 0xE1 && byte1 <= 0xEC) ||
           /* ED, 80..9F, 80..BF */
           (byte1 == 0xED && byte2 <= 0x9F) ||
           /* EE..EF, 80..BF, 80..BF */
           (byte1 >= 0xEE && byte1 <= 0xEF))) {
        codepoint_bytes = 3;
        continue;
      }
      if (len >= 4) {
        const unsigned char byte4 = data[3];
        /* Is byte4 between 0x80 ~ 0xBF */
        if (!utf8_range_IsTrailByteOk(byte4)) {
          return err_pos;
        }

        if (/* F0, 90..BF, 80..BF, 80..BF */
            ((byte1 == 0xF0 && byte2 >= 0x90) ||
             /* F1..F3, 80..BF, 80..BF, 80..BF */
             (byte1 >= 0xF1 && byte1 <= 0xF3) ||
             /* F4, 80..8F, 80..BF, 80..BF */
             (byte1 == 0xF4 && byte2 <= 0x8F))) {
          codepoint_bytes = 4;
          continue;
        }
      }
    }
    /* No well-formed encoding starts at |data| (or it is cut off by |end|). */
    return err_pos;
  }
  /* Account for the last accepted codepoint, which the loop has not added. */
  if (return_position) {
    err_pos += codepoint_bytes;
  }
  /* if return_position is false, this returns 1.
   * if return_position is true, this returns err_pos.
   */
  return err_pos + (1 - return_position);
}

#ifdef __SSE4_1__
/* Returns the number of bytes needed to skip backwards to get to the first
 * byte of the codepoint.
 *
 * |codepoint_word| is read byte by byte through a byte pointer; index 3 is
 * tested first, then 2, then 1. Index 0 is never examined.
 *  - 1 if byte 3 is not a continuation byte,
 *  - 2 if byte 3 is a continuation byte but byte 2 is not,
 *  - 3 if bytes 3 and 2 are continuation bytes but byte 1 is not,
 *  - 0 if bytes 3, 2 and 1 are all continuation bytes.
 */
static inline int utf8_range_CodepointSkipBackwards(int32_t codepoint_word) {
  const int8_t* const codepoint = (const int8_t*)(&codepoint_word);
  if (!utf8_range_IsTrailByteOk(codepoint[3])) {
    return 1;
  } else if (!utf8_range_IsTrailByteOk(codepoint[2])) {
    return 2;
  } else if (!utf8_range_IsTrailByteOk(codepoint[1])) {
    return 3;
  }
  return 0;
}
#endif  // __SSE4_1__

/* Skips over as much leading ASCII as possible, 8 bytes at a time. This is
 * intentional, as most strings checked for validity consist only of 1 byte
 * codepoints.
 *
 * Returns a pointer to the first byte in [data, end) that is not ASCII, or
 * |end| if every byte is ASCII.
 */
static inline const char* utf8_range_SkipAscii(const char* data,
                                               const char* end) {
  /* A zero result after masking with 0x80 per byte means none of the 8 bytes
   * has its high bit set, i.e. all 8 are ASCII.
   */
  while (8 <= end - data &&
         (utf8_range_UnalignedLoad64(data) & 0x8080808080808080) == 0) {
    data += 8;
  }
  /* Finish byte by byte: fewer than 8 bytes are left, or the current 8-byte
   * group contains a non-ASCII byte somewhere.
   */
  while (data < end && utf8_range_AsciiIsAscii(*data)) {
    ++data;
  }
  return data;
}

utf8_range_Validate(const char * data,size_t len,int return_position)178 static FORCE_INLINE_ATTR inline size_t utf8_range_Validate(
179     const char* data, size_t len, int return_position) {
180   if (len == 0) return 1 - return_position;
181   const char* const end = data + len;
182   data = utf8_range_SkipAscii(data, end);
183   /* SIMD algorithm always outperforms the naive version for any data of
184      length >=16.
185    */
186   if (end - data < 16) {
187     return (return_position ? (data - (end - len)) : 0) +
188            utf8_range_ValidateUTF8Naive(data, end, return_position);
189   }
190 #ifndef __SSE4_1__
191   return (return_position ? (data - (end - len)) : 0) +
192          utf8_range_ValidateUTF8Naive(data, end, return_position);
193 #else
194   /* This code checks that utf-8 ranges are structurally valid 16 bytes at once
195    * using superscalar instructions.
196    * The mapping between ranges of codepoint and their corresponding utf-8
197    * sequences is below.
198    */
199 
200   /*
201    * U+0000...U+007F     00...7F
202    * U+0080...U+07FF     C2...DF 80...BF
203    * U+0800...U+0FFF     E0      A0...BF 80...BF
204    * U+1000...U+CFFF     E1...EC 80...BF 80...BF
205    * U+D000...U+D7FF     ED      80...9F 80...BF
206    * U+E000...U+FFFF     EE...EF 80...BF 80...BF
207    * U+10000...U+3FFFF   F0      90...BF 80...BF 80...BF
208    * U+40000...U+FFFFF   F1...F3 80...BF 80...BF 80...BF
209    * U+100000...U+10FFFF F4      80...8F 80...BF 80...BF
210    */
211 
212   /* First we compute the type for each byte, as given by the table below.
213    * This type will be used as an index later on.
214    */
215 
216   /*
217    * Index  Min Max Byte Type
218    *  0     00  7F  Single byte sequence
219    *  1,2,3 80  BF  Second, third and fourth byte for many of the sequences.
220    *  4     A0  BF  Second byte after E0
221    *  5     80  9F  Second byte after ED
222    *  6     90  BF  Second byte after F0
223    *  7     80  8F  Second byte after F4
224    *  8     C2  F4  First non ASCII byte
225    *  9..15 7F  80  Invalid byte
226    */
227 
228   /* After the first step we compute the index for all bytes, then we permute
229      the bytes according to their indices to check the ranges from the range
230      table.
231    * The range for a given type can be found in the range_min_table and
232      range_max_table, the range for type/index X is in range_min_table[X] ...
233      range_max_table[X].
234    */
235 
236   /* Algorithm:
237    * Put index zero to all bytes.
238    * Find all non ASCII characters, give them index 8.
239    * For each tail byte in a codepoint sequence, give it an index corresponding
240      to the 1 based index from the end.
241    * If the first byte of the codepoint is in the [C0...DF] range, we write
242      index 1 in the following byte.
243    * If the first byte of the codepoint is in the range [E0...EF], we write
244      indices 2 and 1 in the next two bytes.
245    * If the first byte of the codepoint is in the range [F0...FF] we write
246      indices 3,2,1 into the next three bytes.
247    * For finding the number of bytes we need to look at high nibbles (4 bits)
248      and do the lookup from the table, it can be done with shift by 4 + shuffle
249      instructions. We call it `first_len`.
250    * Then we shift first_len by 8 bits to get the indices of the 2nd bytes.
251    * Saturating sub 1 and shift by 8 bits to get the indices of the 3rd bytes.
252    * Again to get the indices of the 4th bytes.
253    * Take OR of all that 4 values and check within range.
254    */
255   /* For example:
256    * input       C3 80 68 E2 80 20 A6 F0 A0 80 AC 20 F0 93 80 80
257    * first_len   1  0  0  2  0  0  0  3  0  0  0  0  3  0  0  0
258    * 1st byte    8  0  0  8  0  0  0  8  0  0  0  0  8  0  0  0
259    * 2nd byte    0  1  0  0  2  0  0  0  3  0  0  0  0  3  0  0 // Shift + sub
260    * 3rd byte    0  0  0  0  0  1  0  0  0  2  0  0  0  0  2  0 // Shift + sub
261    * 4th byte    0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  1 // Shift + sub
262    * Index       8  1  0  8  2  1  0  8  3  2  1  0  8  3  2  1 // OR of results
263    */
264 
265   /* Checking for errors:
266    * Error checking is done by looking up the high nibble (4 bits) of each byte
267      against an error checking table.
268    * Because the lookup value for the second byte depends of the value of the
269      first byte in codepoint, we use saturated operations to adjust the index.
270    * Specifically we need to add 2 for E0, 3 for ED, 3 for F0 and 4 for F4 to
271      match the correct index.
272        * If we subtract from all bytes EF then EO -> 241, ED -> 254, F0 -> 1,
273          F4 -> 5
274        * Do saturating sub 240, then E0 -> 1, ED -> 14 and we can do lookup to
275          match the adjustment
276        * Add saturating 112, then F0 -> 113, F4 -> 117, all that were > 16 will
277          be more 128 and lookup in ef_fe_table will return 0 but for F0
278          and F4 it will be 4 and 5 accordingly
279    */
280   /*
281    * Then just check the appropriate ranges with greater/smaller equal
282      instructions. Check tail with a naive algorithm.
283    * To save from previous 16 byte checks we just align previous_first_len to
284      get correct continuations of the codepoints.
285    */
286 
287   /*
288    * Map high nibble of "First Byte" to legal character length minus 1
289    * 0x00 ~ 0xBF --> 0
290    * 0xC0 ~ 0xDF --> 1
291    * 0xE0 ~ 0xEF --> 2
292    * 0xF0 ~ 0xFF --> 3
293    */
294   const __m128i first_len_table =
295       _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3);
296 
297   /* Map "First Byte" to 8-th item of range table (0xC2 ~ 0xF4) */
298   const __m128i first_range_table =
299       _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8);
300 
301   /*
302    * Range table, map range index to min and max values
303    */
304   const __m128i range_min_table =
305       _mm_setr_epi8(0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80, 0xC2, 0x7F,
306                     0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F);
307 
308   const __m128i range_max_table =
309       _mm_setr_epi8(0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F, 0xF4, 0x80,
310                     0x80, 0x80, 0x80, 0x80, 0x80, 0x80);
311 
312   /*
313    * Tables for fast handling of four special First Bytes(E0,ED,F0,F4), after
314    * which the Second Byte are not 80~BF. It contains "range index adjustment".
315    * +------------+---------------+------------------+----------------+
316    * | First Byte | original range| range adjustment | adjusted range |
317    * +------------+---------------+------------------+----------------+
318    * | E0         | 2             | 2                | 4              |
319    * +------------+---------------+------------------+----------------+
320    * | ED         | 2             | 3                | 5              |
321    * +------------+---------------+------------------+----------------+
322    * | F0         | 3             | 3                | 6              |
323    * +------------+---------------+------------------+----------------+
324    * | F4         | 4             | 4                | 8              |
325    * +------------+---------------+------------------+----------------+
326    */
327 
328   /* df_ee_table[1] -> E0, df_ee_table[14] -> ED as ED - E0 = 13 */
329   // The values represent the adjustment in the Range Index table for a correct
330   // index.
331   const __m128i df_ee_table =
332       _mm_setr_epi8(0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0);
333 
334   /* ef_fe_table[1] -> F0, ef_fe_table[5] -> F4, F4 - F0 = 4 */
335   // The values represent the adjustment in the Range Index table for a correct
336   // index.
337   const __m128i ef_fe_table =
338       _mm_setr_epi8(0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
339 
340   __m128i prev_input = _mm_set1_epi8(0);
341   __m128i prev_first_len = _mm_set1_epi8(0);
342   __m128i error = _mm_set1_epi8(0);
343   while (end - data >= 16) {
344     const __m128i input =
345         _mm_loadu_si128((const __m128i*)(data));
346 
347     /* high_nibbles = input >> 4 */
348     const __m128i high_nibbles =
349         _mm_and_si128(_mm_srli_epi16(input, 4), _mm_set1_epi8(0x0F));
350 
351     /* first_len = legal character length minus 1 */
352     /* 0 for 00~7F, 1 for C0~DF, 2 for E0~EF, 3 for F0~FF */
353     /* first_len = first_len_table[high_nibbles] */
354     __m128i first_len = _mm_shuffle_epi8(first_len_table, high_nibbles);
355 
356     /* First Byte: set range index to 8 for bytes within 0xC0 ~ 0xFF */
357     /* range = first_range_table[high_nibbles] */
358     __m128i range = _mm_shuffle_epi8(first_range_table, high_nibbles);
359 
360     /* Second Byte: set range index to first_len */
361     /* 0 for 00~7F, 1 for C0~DF, 2 for E0~EF, 3 for F0~FF */
362     /* range |= (first_len, prev_first_len) << 1 byte */
363     range = _mm_or_si128(range, _mm_alignr_epi8(first_len, prev_first_len, 15));
364 
365     /* Third Byte: set range index to saturate_sub(first_len, 1) */
366     /* 0 for 00~7F, 0 for C0~DF, 1 for E0~EF, 2 for F0~FF */
367     __m128i tmp1;
368     __m128i tmp2;
369     /* tmp1 = saturate_sub(first_len, 1) */
370     tmp1 = _mm_subs_epu8(first_len, _mm_set1_epi8(1));
371     /* tmp2 = saturate_sub(prev_first_len, 1) */
372     tmp2 = _mm_subs_epu8(prev_first_len, _mm_set1_epi8(1));
373     /* range |= (tmp1, tmp2) << 2 bytes */
374     range = _mm_or_si128(range, _mm_alignr_epi8(tmp1, tmp2, 14));
375 
376     /* Fourth Byte: set range index to saturate_sub(first_len, 2) */
377     /* 0 for 00~7F, 0 for C0~DF, 0 for E0~EF, 1 for F0~FF */
378     /* tmp1 = saturate_sub(first_len, 2) */
379     tmp1 = _mm_subs_epu8(first_len, _mm_set1_epi8(2));
380     /* tmp2 = saturate_sub(prev_first_len, 2) */
381     tmp2 = _mm_subs_epu8(prev_first_len, _mm_set1_epi8(2));
382     /* range |= (tmp1, tmp2) << 3 bytes */
383     range = _mm_or_si128(range, _mm_alignr_epi8(tmp1, tmp2, 13));
384 
385     /*
386      * Now we have below range indices calculated
387      * Correct cases:
388      * - 8 for C0~FF
389      * - 3 for 1st byte after F0~FF
390      * - 2 for 1st byte after E0~EF or 2nd byte after F0~FF
391      * - 1 for 1st byte after C0~DF or 2nd byte after E0~EF or
392      *         3rd byte after F0~FF
393      * - 0 for others
394      * Error cases:
395      *   >9 for non ascii First Byte overlapping
396      *   E.g., F1 80 C2 90 --> 8 3 10 2, where 10 indicates error
397      */
398 
399     /* Adjust Second Byte range for special First Bytes(E0,ED,F0,F4) */
400     /* Overlaps lead to index 9~15, which are illegal in range table */
401     __m128i shift1;
402     __m128i pos;
403     __m128i range2;
404     /* shift1 = (input, prev_input) << 1 byte */
405     shift1 = _mm_alignr_epi8(input, prev_input, 15);
406     pos = _mm_sub_epi8(shift1, _mm_set1_epi8(0xEF));
407     /*
408      * shift1:  | EF  F0 ... FE | FF  00  ... ...  DE | DF  E0 ... EE |
409      * pos:     | 0   1      15 | 16  17           239| 240 241    255|
410      * pos-240: | 0   0      0  | 0   0            0  | 0   1      15 |
411      * pos+112: | 112 113    127|       >= 128        |     >= 128    |
412      */
413     tmp1 = _mm_subs_epu8(pos, _mm_set1_epi8(-16));
414     range2 = _mm_shuffle_epi8(df_ee_table, tmp1);
415     tmp2 = _mm_adds_epu8(pos, _mm_set1_epi8(112));
416     range2 = _mm_add_epi8(range2, _mm_shuffle_epi8(ef_fe_table, tmp2));
417 
418     range = _mm_add_epi8(range, range2);
419 
420     /* Load min and max values per calculated range index */
421     __m128i min_range = _mm_shuffle_epi8(range_min_table, range);
422     __m128i max_range = _mm_shuffle_epi8(range_max_table, range);
423 
424     /* Check value range */
425     if (return_position) {
426       error = _mm_cmplt_epi8(input, min_range);
427       error = _mm_or_si128(error, _mm_cmpgt_epi8(input, max_range));
428       /* 5% performance drop from this conditional branch */
429       if (!_mm_testz_si128(error, error)) {
430         break;
431       }
432     } else {
433       error = _mm_or_si128(error, _mm_cmplt_epi8(input, min_range));
434       error = _mm_or_si128(error, _mm_cmpgt_epi8(input, max_range));
435     }
436 
437     prev_input = input;
438     prev_first_len = first_len;
439 
440     data += 16;
441   }
442   /* If we got to the end, we don't need to skip any bytes backwards */
443   if (return_position && (data - (end - len)) == 0) {
444     return utf8_range_ValidateUTF8Naive(data, end, return_position);
445   }
446   /* Find previous codepoint (not 80~BF) */
447   data -= utf8_range_CodepointSkipBackwards(_mm_extract_epi32(prev_input, 3));
448   if (return_position) {
449     return (data - (end - len)) +
450            utf8_range_ValidateUTF8Naive(data, end, return_position);
451   }
452   /* Test if there was any error */
453   if (!_mm_testz_si128(error, error)) {
454     return 0;
455   }
456   /* Check the tail */
457   return utf8_range_ValidateUTF8Naive(data, end, return_position);
458 #endif
459 }
460 
/* Returns 1 if the |len| bytes starting at |data| are valid UTF-8, otherwise
 * 0. An empty input (len == 0) is valid.
 */
int utf8_range_IsValid(const char* data, size_t len) {
  return utf8_range_Validate(data, len, /*return_position=*/0) != 0;
}

/* Returns the length in bytes of the longest prefix of the |len| bytes at
 * |data| that is valid UTF-8. Returns 0 for an empty input.
 */
size_t utf8_range_ValidPrefix(const char* data, size_t len) {
  return utf8_range_Validate(data, len, /*return_position=*/1);
}
468