1 #include <sys/time.h>
2
3 #include <algorithm>
4 #include <cassert>
5 #include <cstdint>
6 #include <cstdio>
7 #include <cstring>
8 #include <vector>
9
/*
 * Baseline ASCII check built on the standard library.
 *
 * Returns 1 when none of the bytes in data[0, len) has its top bit set
 * (an empty range counts as ASCII), 0 otherwise.
 */
static inline int ascii_std(const uint8_t *data, int len) {
  const uint8_t *const last = data + len;
  const auto has_high_bit = [](uint8_t byte) { return (byte & 0x80) != 0; };
  return std::none_of(data, last, has_high_bit);
}
13
/*
 * ASCII check that consumes the input 16 bytes at a time as two 64-bit words.
 *
 * Returns 1 when every byte in data[0, len) has its top bit clear, 0
 * otherwise. A len of zero or less is treated as an empty (ASCII) input.
 *
 * The words are fetched with memcpy instead of dereferencing a casted
 * pointer: data may sit at any address, and a direct uint64_t load through
 * a uint8_t pointer is undefined behaviour (misaligned access and aliasing).
 */
static inline int ascii_u64(const uint8_t *data, int len) {
  uint8_t orall = 0;

  if (len >= 16) {
    uint64_t or1 = 0, or2 = 0;

    do {
      uint64_t word1, word2;
      std::memcpy(&word1, data, sizeof(word1));
      std::memcpy(&word2, data + 8, sizeof(word2));

      or1 |= word1;
      or2 |= word2;

      data += 16;
      len -= 16;
    } while (len >= 16);

    /*
     * Idea from Benny Halevy:
     * - some top bit set   ==> orall = !(non-zero) - 1 = 0 - 1 = 0xFF
     * - all top bits clear ==> orall = !0 - 1          = 1 - 1 = 0x00
     */
    orall = !((or1 | or2) & 0x8080808080808080ULL) - 1;
  }

  /* Fewer than 16 bytes remain; fold them in one at a time. */
  for (; len > 0; --len) orall |= *data++;

  return orall < 0x80;
}
41
42 #if defined(__x86_64__)
43 #include <x86intrin.h>
44
ascii_simd(const uint8_t * data,int len)45 static inline int ascii_simd(const uint8_t *data, int len) {
46 if (len >= 32) {
47 const uint8_t *data2 = data + 16;
48
49 __m128i or1 = _mm_set1_epi8(0), or2 = or1;
50
51 while (len >= 32) {
52 __m128i input1 = _mm_loadu_si128((const __m128i *)data);
53 __m128i input2 = _mm_loadu_si128((const __m128i *)data2);
54
55 or1 = _mm_or_si128(or1, input1);
56 or2 = _mm_or_si128(or2, input2);
57
58 data += 32;
59 data2 += 32;
60 len -= 32;
61 }
62
63 or1 = _mm_or_si128(or1, or2);
64 if (_mm_movemask_epi8(_mm_cmplt_epi8(or1, _mm_set1_epi8(0)))) return 0;
65 }
66
67 return ascii_u64(data, len);
68 }
69
70 #elif defined(__aarch64__)
71 #include <arm_neon.h>
72
/*
 * AArch64 (NEON) ASCII check.
 *
 * While at least 32 bytes remain, two 16-byte loads per iteration are OR-ed
 * into two independent accumulators. If the largest accumulated byte is
 * 0x80 or above the function returns 0 right away; otherwise the leftover
 * (< 32) bytes are handed to ascii_u64, whose result is returned.
 */
static inline int ascii_simd(const uint8_t *data, int len) {
  if (len >= 32) {
    uint8x16_t acc_lo = vdupq_n_u8(0);
    uint8x16_t acc_hi = acc_lo;

    for (; len >= 32; len -= 32, data += 32) {
      acc_lo = vorrq_u8(acc_lo, vld1q_u8(data));
      acc_hi = vorrq_u8(acc_hi, vld1q_u8(data + 16));
    }

    /* horizontal max >= 0x80 means some byte had its top bit set */
    const uint8x16_t merged = vorrq_u8(acc_lo, acc_hi);
    if (vmaxvq_u8(merged) >= 0x80) return 0;
  }

  return ascii_u64(data, len);
}
97
98 #endif
99
/* Table entry pairing a checker routine with the label used in the output. */
struct ftab {
  /* Signature shared by all checkers: (buffer, length in bytes) -> int. */
  using check_fn = int (*)(const uint8_t *data, int len);

  const char *name; /* label printed by test/bench and matched against [alg] */
  check_fn func;    /* routine under test */
};
104
105 static const std::vector<ftab> _f = {
106 {
107 .name = "std",
108 .func = ascii_std,
109 },
110 {
111 .name = "u64",
112 .func = ascii_u64,
113 },
114 {
115 .name = "simd",
116 .func = ascii_simd,
117 },
118 };
119
/*
 * Fills data[0, len) with the repeating pattern 0x00, 0x01, ..., 0x7F, so
 * every byte written has its top bit clear. Writes nothing when len <= 0.
 */
static void load_test_buf(uint8_t *data, int len) {
  for (int pos = 0; pos < len; ++pos) {
    data[pos] = static_cast<uint8_t>(pos & 0x7F);
  }
}
128
bench(const struct ftab & f,const uint8_t * data,int len)129 static void bench(const struct ftab &f, const uint8_t *data, int len) {
130 const int loops = 1024 * 1024 * 1024 / len;
131 int ret = 1;
132 double time_aligned, time_unaligned, size;
133 struct timeval tv1, tv2;
134
135 fprintf(stderr, "bench %s (%d bytes)... ", f.name, len);
136
137 /* aligned */
138 gettimeofday(&tv1, 0);
139 for (int i = 0; i < loops; ++i) ret &= f.func(data, len);
140 gettimeofday(&tv2, 0);
141 time_aligned = tv2.tv_usec - tv1.tv_usec;
142 time_aligned = time_aligned / 1000000 + tv2.tv_sec - tv1.tv_sec;
143
144 /* unaligned */
145 gettimeofday(&tv1, 0);
146 for (int i = 0; i < loops; ++i) ret &= f.func(data + 1, len);
147 gettimeofday(&tv2, 0);
148 time_unaligned = tv2.tv_usec - tv1.tv_usec;
149 time_unaligned = time_unaligned / 1000000 + tv2.tv_sec - tv1.tv_sec;
150
151 printf("%s ", ret ? "pass" : "FAIL");
152
153 size = ((double)len * loops) / (1024 * 1024);
154 printf("%.0f/%.0f MB/s\n", size / time_aligned, size / time_unaligned);
155 }
156
test(const struct ftab & f,uint8_t * data,int len)157 static void test(const struct ftab &f, uint8_t *data, int len) {
158 int error = 0;
159
160 fprintf(stderr, "test %s (%d bytes)... ", f.name, len);
161
162 /* positive */
163 error |= !f.func(data, len);
164
165 /* negative */
166 if (len < 100 * 1024) {
167 for (int i = 0; i < len; ++i) {
168 data[i] += 0x80;
169 error |= f.func(data, len);
170 data[i] -= 0x80;
171 }
172 }
173
174 printf("%s\n", error ? "FAIL" : "pass");
175 }
176
177 /* ./ascii [test|bench] [alg] */
main(int argc,const char * argv[])178 int main(int argc, const char *argv[]) {
179 int do_test = 1, do_bench = 1;
180 const char *alg = NULL;
181
182 if (argc > 1) {
183 do_bench &= !!strcmp(argv[1], "test");
184 do_test &= !!strcmp(argv[1], "bench");
185 }
186
187 if (do_bench && argc > 2) alg = argv[2];
188
189 const std::vector<int> size = {
190 9, 16 + 1, 32 - 1, 128 + 1,
191 1024 + 15, 16 * 1024 + 1, 64 * 1024 + 15, 1024 * 1024};
192
193 int max_size = *std::max_element(size.begin(), size.end());
194 uint8_t *_data = new uint8_t[max_size + 1];
195 assert(((uintptr_t)_data & 7) == 0);
196 uint8_t *data = _data + 1; /* Unalign buffer address */
197
198 _data[0] = 0;
199 load_test_buf(data, max_size);
200
201 if (do_test) {
202 printf("==================== Test ====================\n");
203 for (int sz : size) {
204 for (auto &f : _f) {
205 test(f, data, sz);
206 }
207 }
208 }
209
210 if (do_bench) {
211 printf("==================== Bench ====================\n");
212 for (int sz : size) {
213 for (auto &f : _f) {
214 if (!alg || strcmp(alg, f.name) == 0) bench(f, _data, sz);
215 }
216 printf("-----------------------------------------------\n");
217 }
218 }
219
220 delete _data;
221 return 0;
222 }
223