xref: /aosp_15_r20/external/grpc-grpc/third_party/utf8_range/lemire-neon.c (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Adapted from https://github.com/lemire/fastvalidate-utf-8
2 
3 #ifdef __aarch64__
4 
5 #include <stdio.h>
6 #include <stddef.h>
7 #include <stdint.h>
8 #include <string.h>
9 #include <inttypes.h>
10 #include <arm_neon.h>
11 
12 /*
13  * legal utf-8 byte sequence
14  * http://www.unicode.org/versions/Unicode6.0.0/ch03.pdf - page 94
15  *
16  *  Code Points        1st       2s       3s       4s
17  * U+0000..U+007F     00..7F
18  * U+0080..U+07FF     C2..DF   80..BF
19  * U+0800..U+0FFF     E0       A0..BF   80..BF
20  * U+1000..U+CFFF     E1..EC   80..BF   80..BF
21  * U+D000..U+D7FF     ED       80..9F   80..BF
22  * U+E000..U+FFFF     EE..EF   80..BF   80..BF
23  * U+10000..U+3FFFF   F0       90..BF   80..BF   80..BF
24  * U+40000..U+FFFFF   F1..F3   80..BF   80..BF   80..BF
25  * U+100000..U+10FFFF F4       80..8F   80..BF   80..BF
26  *
27  */
28 
29 #if 0
// Debug helper: dumps the 16 lanes of *v128 as two-digit hex values on one
// line, prefixed by "<s>:\t" when a label is supplied.
static void print128(const char *s, const int8x16_t *v128)
{
    int8_t lanes[16];
    vst1q_s8(lanes, *v128);

    if (s != NULL) {
        printf("%s:\t", s);
    }
    for (size_t k = 0; k < sizeof lanes; ++k) {
        printf("%02x ", (unsigned char)lanes[k]);
    }
    printf("\n");
}
41 #endif
42 
// Flags every lane whose unsigned value exceeds 0xF4; such a lane leaves a
// non-zero value in *has_error.
static inline void checkSmallerThan0xF4(int8x16_t current_bytes,
                                        int8x16_t *has_error) {
  const uint8x16_t limit = vdupq_n_u8(0xF4);
  const uint8x16_t raw = vreinterpretq_u8_s8(current_bytes);
  // Unsigned saturating subtract: lanes <= 0xF4 collapse to 0, anything larger
  // keeps a non-zero remainder.
  const uint8x16_t excess = vqsubq_u8(raw, limit);
  *has_error = vorrq_s8(*has_error, vreinterpretq_s8_u8(excess));
}
50 
// Sequence length announced by a byte, indexed by the byte's high nibble.
// The name deliberately avoids a leading underscore, which is reserved at
// file scope. The explicit bound keeps the 16-byte vector load in range.
static const int8_t utf8_length_by_high_nibble[16] = {
  1, 1, 1, 1, 1, 1, 1, 1, // 0xxx (ASCII)
  0, 0, 0, 0,             // 10xx (continuation)
  2, 2,                   // 110x
  3,                      // 1110
  4, // 1111, next should be 0 (not checked here)
};

// Translates each lane of high_nibbles (expected to hold values 0..15) into
// the length of the sequence that byte starts: 1 for ASCII, 0 for a
// continuation byte, 2..4 for lead bytes.
static inline int8x16_t continuationLengths(int8x16_t high_nibbles) {
  const int8x16_t table = vld1q_s8(utf8_length_by_high_nibble);
  return vqtbl1q_s8(table, vreinterpretq_u8_s8(high_nibbles));
}
62 
// Spreads each announced sequence length over the bytes that follow it, so
// that lengths 3,0,0 turn into carries 3,2,1. previous_carries supplies the
// last two lanes of the preceding block so sequences can straddle blocks.
static inline int8x16_t carryContinuations(int8x16_t initial_lengths,
                                           int8x16_t previous_carries) {
  const uint8x16_t one = vdupq_n_u8(1);
  const uint8x16_t two = vdupq_n_u8(2);

  // Step 1: each lane receives (length of the byte one position back) - 1,
  // saturating at zero.
  const int8x16_t back1 = vextq_s8(previous_carries, initial_lengths, 16 - 1);
  const int8x16_t carry1 =
      vreinterpretq_s8_u8(vqsubq_u8(vreinterpretq_u8_s8(back1), one));
  const int8x16_t partial = vaddq_s8(initial_lengths, carry1);

  // Step 2: each lane receives (partial sum two positions back) - 2,
  // saturating at zero, which covers the third and fourth byte of a sequence.
  const int8x16_t back2 = vextq_s8(previous_carries, partial, 16 - 2);
  const int8x16_t carry2 =
      vreinterpretq_s8_u8(vqsubq_u8(vreinterpretq_u8_s8(back2), two));

  return vaddq_s8(partial, carry2);
}
75 
// Flags lanes where the carried count and the announced length disagree:
//   overlap  - a byte that starts a sequence (length > 0) while the carry is
//              larger than its own length
//   underlap - a continuation byte (length == 0) whose carry is not larger
//              than its length
// Both cases reduce to (carry > length) == (length > 0).
static inline void checkContinuations(int8x16_t initial_lengths, int8x16_t carries,
                                      int8x16_t *has_error) {
  const uint8x16_t carry_exceeds_length = vcgtq_s8(carries, initial_lengths);
  const uint8x16_t starts_sequence = vcgtq_s8(initial_lengths, vdupq_n_s8(0));
  const uint8x16_t mismatch = vceqq_u8(carry_exceeds_length, starts_sequence);

  *has_error = vorrq_s8(*has_error, vreinterpretq_s8_u8(mismatch));
}
88 
// Upper bound on the byte that follows certain lead bytes:
//   after 0xED the next byte must not exceed 0x9F
//   after 0xF4 the next byte must not exceed 0x8F
// off1_current_bytes is the byte stream shifted by one lane, so each lane of
// it holds the byte preceding the matching lane of current_bytes. The
// comparisons are signed.
static inline void checkFirstContinuationMax(int8x16_t current_bytes,
                                             int8x16_t off1_current_bytes,
                                             int8x16_t *has_error) {
  const uint8x16_t after_ed =
      vceqq_s8(off1_current_bytes, vdupq_n_s8((int8_t)0xED));
  const uint8x16_t after_f4 =
      vceqq_s8(off1_current_bytes, vdupq_n_s8((int8_t)0xF4));

  const uint8x16_t above_9f = vcgtq_s8(current_bytes, vdupq_n_s8((int8_t)0x9F));
  const uint8x16_t above_8f = vcgtq_s8(current_bytes, vdupq_n_s8((int8_t)0x8F));

  const uint8x16_t violations =
      vorrq_u8(vandq_u8(above_9f, after_ed), vandq_u8(above_8f, after_f4));

  *has_error = vorrq_s8(*has_error, vreinterpretq_s8_u8(violations));
}
105 
// Overlong-encoding lookup tables, indexed by the high nibble of the byte one
// position back. The names avoid a leading underscore (reserved at file scope)
// and byte constants above 0x7F are cast explicitly rather than narrowed
// implicitly. -128 is the smallest int8_t, so "table > byte" can never hold
// for those rows.

// Signed lower bound on the preceding (lead) byte.
static const int8_t overlong_lead_mins[16] = {
  -128, -128, -128, -128, -128, -128, -128, -128, // 0xxx => false
  -128, -128, -128, -128,                         // 10xx => false
  (int8_t)0xC2, -128,                             // 110x: C0 and C1 are suspect
  (int8_t)0xE1,                                   // 1110: E0 is suspect
  (int8_t)0xF1,                                   // 1111: F0 is suspect
};

// Signed lower bound on the current byte, given the preceding byte's nibble.
static const int8_t overlong_second_mins[16] = {
  -128, -128, -128, -128, -128, -128, -128, -128, // 0xxx => false
  -128, -128, -128, -128,                         // 10xx => false
  127, 127,     // 110x => true for every byte except 0x7F itself
  (int8_t)0xA0, // 1110: after E0 the next byte must be at least A0
  (int8_t)0x90, // 1111: after F0 the next byte must be at least 90
};

// map off1_hibits => error condition
// hibits     off1    cur
// C       => < C2 && true
// E       => < E1 && < A0
// F       => < F1 && < 90
// else      false && false
//
// A lane is flagged only when BOTH the preceding byte and the current byte
// fall below their respective minimum (signed comparison).
// previous_hibits supplies the high nibble of the last byte of the preceding
// block so the check works across block boundaries.
static inline void checkOverlong(int8x16_t current_bytes,
                                 int8x16_t off1_current_bytes, int8x16_t hibits,
                                 int8x16_t previous_hibits, int8x16_t *has_error) {
  // High nibble of the byte preceding each lane; used as the table index.
  const int8x16_t off1_hibits = vextq_s8(previous_hibits, hibits, 16 - 1);
  const uint8x16_t index = vreinterpretq_u8_s8(off1_hibits);

  const int8x16_t lead_mins = vqtbl1q_s8(vld1q_s8(overlong_lead_mins), index);
  const int8x16_t second_mins = vqtbl1q_s8(vld1q_s8(overlong_second_mins), index);

  const uint8x16_t lead_under = vcgtq_s8(lead_mins, off1_current_bytes);
  const uint8x16_t second_under = vcgtq_s8(second_mins, current_bytes);

  const uint8x16_t overlong = vandq_u8(lead_under, second_under);
  *has_error = vorrq_s8(*has_error, vreinterpretq_s8_u8(overlong));
}
141 
// Per-block state carried from one 16-byte block to the next.
struct processed_utf_bytes {
  int8x16_t rawbytes;              // the block's bytes, unmodified
  int8x16_t high_nibbles;          // each byte shifted right by four bits
  int8x16_t carried_continuations; // output of carryContinuations for the block
};

// Records the raw bytes and their high nibbles in *answer.
// carried_continuations is left untouched; the caller fills it in.
static inline void count_nibbles(int8x16_t bytes,
                                 struct processed_utf_bytes *answer) {
  // Shift as unsigned so the top nibble lands in 0..15 with no sign extension.
  const uint8x16_t top_nibbles = vshrq_n_u8(vreinterpretq_u8_s8(bytes), 4);

  answer->high_nibbles = vreinterpretq_s8_u8(top_nibbles);
  answer->rawbytes = bytes;
}
154 
// Runs every per-block UTF-8 check on current_bytes and ORs any violation
// into *has_error. *previous (the state of the preceding block) is only read
// here; the state for this block is returned, and the caller is expected to
// store it as the new "previous".
static inline struct processed_utf_bytes
checkUTF8Bytes(int8x16_t current_bytes, struct processed_utf_bytes *previous,
               int8x16_t *has_error) {
  struct processed_utf_bytes state;
  count_nibbles(current_bytes, &state);

  // Byte stream shifted by one lane: lane i holds the byte before lane i of
  // current_bytes, with lane 0 taken from the end of the previous block.
  const int8x16_t shifted_bytes =
      vextq_s8(previous->rawbytes, state.rawbytes, 16 - 1);

  // Sequence-length bookkeeping.
  const int8x16_t lengths = continuationLengths(state.high_nibbles);
  state.carried_continuations =
      carryContinuations(lengths, previous->carried_continuations);

  // Each check only ORs into *has_error.
  checkSmallerThan0xF4(current_bytes, has_error);
  checkContinuations(lengths, state.carried_continuations, has_error);
  checkFirstContinuationMax(current_bytes, shifted_bytes, has_error);
  checkOverlong(current_bytes, shifted_bytes, state.high_nibbles,
                previous->high_nibbles, has_error);

  return state;
}
180 
// Per-lane thresholds compared against the last block's carried continuation
// counts when the input ends exactly on a 16-byte boundary. The final lane
// uses 1: a carry above 1 there means the last byte still expects
// continuation bytes that never arrive. The other lanes use 9.
// (Renamed from a leading-underscore identifier, which is reserved at file
// scope.)
static const int8_t trailing_carry_limits[16] = {9, 9, 9, 9, 9, 9, 9, 9,
                                                 9, 9, 9, 9, 9, 9, 9, 1};

/*
 * Validate that the len bytes starting at src are well-formed UTF-8.
 *
 * src: input buffer; only read, never written.
 * len: number of bytes to validate. A negative value is rejected.
 *
 * Return 0 on success, -1 on error (invalid UTF-8 or negative len).
 */
int utf8_lemire(const unsigned char *src, int len) {
  // A negative length would convert to a huge size_t in the comparisons below
  // and reach memcpy with an enormous byte count; refuse it up front.
  if (len < 0) {
    return -1;
  }

  const size_t total = (size_t)len;
  size_t i = 0;
  int8x16_t has_error = vdupq_n_s8(0);
  struct processed_utf_bytes previous = {.rawbytes = vdupq_n_s8(0),
                                         .high_nibbles = vdupq_n_s8(0),
                                         .carried_continuations =
                                             vdupq_n_s8(0)};

  // Whole 16-byte blocks. i never exceeds total, so the subtraction is safe.
  for (; total - i >= 16; i += 16) {
    int8x16_t current_bytes = vld1q_s8((const int8_t *)(src + i));
    previous = checkUTF8Bytes(current_bytes, &previous, &has_error);
  }

  if (i < total) {
    // Tail shorter than 16 bytes: copy it into a zero-padded block. The
    // trailing 0x00 padding is ASCII, so a sequence cut off at the end of the
    // input is caught by the regular continuation check.
    int8_t buffer[16] = {0};
    memcpy(buffer, src + i, total - i);
    int8x16_t current_bytes = vld1q_s8(buffer);
    previous = checkUTF8Bytes(current_bytes, &previous, &has_error);
  } else {
    // No padded tail was processed, so check explicitly that the last block
    // did not end in the middle of a multi-byte sequence.
    has_error =
        vorrq_s8(vreinterpretq_s8_u8(vcgtq_s8(previous.carried_continuations,
                                              vld1q_s8(trailing_carry_limits))),
                 has_error);
  }

  return vmaxvq_u8(vreinterpretq_u8_s8(has_error)) == 0 ? 0 : -1;
}
214 
215 #endif
216