1 //// --------------------------
2 //// ATTENTION:
3 //// THIS CODE IS AUTOGENERATED
4 //// BY hp_emblookup_codegen.py
5 //// DO NOT MODIFY!!!
6 //// --------------------------
7
8 #include <c10/util/Half.h>
9 #include <c10/util/BFloat16.h>
10 #include <immintrin.h>
11 namespace caffe2 {
12
13 template <bool IS_WEIGHT_POSITIONAL>
EmbeddingLookupIdx_int32_t_float_float__avx2_fma(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const float * input,const int * indices,const int * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,float * out)14 static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma(
15 const int64_t block_size,
16 const int64_t output_size,
17 const int64_t index_size,
18 const int64_t data_size,
19 const float* input,
20 const int* indices,
21 const int* offsets,
22 const float* weights,
23 const float* scale_bias,
24 bool normalize_by_lengths,
25 float* out) {
26 const int prefdist_T0 = 16;
27 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
28 const int fused_block_size = block_size + 0;
29 int64_t dataInd = 0;
30 if (block_size == 128) {
31 // unrolling 16 times
32 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
33 float* op = &out[rangeIndex * block_size];
34 __m256 vop0 = _mm256_setzero_ps();
35 __m256 vop8 = _mm256_setzero_ps();
36 __m256 vop16 = _mm256_setzero_ps();
37 __m256 vop24 = _mm256_setzero_ps();
38 __m256 vop32 = _mm256_setzero_ps();
39 __m256 vop40 = _mm256_setzero_ps();
40 __m256 vop48 = _mm256_setzero_ps();
41 __m256 vop56 = _mm256_setzero_ps();
42 __m256 vop64 = _mm256_setzero_ps();
43 __m256 vop72 = _mm256_setzero_ps();
44 __m256 vop80 = _mm256_setzero_ps();
45 __m256 vop88 = _mm256_setzero_ps();
46 __m256 vop96 = _mm256_setzero_ps();
47 __m256 vop104 = _mm256_setzero_ps();
48 __m256 vop112 = _mm256_setzero_ps();
49 __m256 vop120 = _mm256_setzero_ps();
50 if (dataInd != offsets[rangeIndex] - offsets[0]) {
51 return false;
52 }
53 int64_t end_offset = offsets[rangeIndex + 1];
54 int64_t length = end_offset - offsets[rangeIndex];
55 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
56 ++dataInd) {
57 const int idx = indices[dataInd];
58 if (idx < 0 || idx >= data_size) {
59 return false;
60 }
61 float wgt = 1.f;
62 if (weights) {
63 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
64 }
65 __m256 vwgt = _mm256_set1_ps(wgt);
66 const float* ip = &input[idx * fused_block_size];
67 const int next_T0 = (dataInd < index_size - prefdist_T0)
68 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
69 ? (dataInd + prefdist_T0)
70 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
71 : dataInd;
72 const int idx_pref_T0 = indices[next_T0];
73 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
74 return false;
75 }
76 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
77 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
78 _mm_prefetch(
79 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
80 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
81 // skip unnecessary prefetch of (&ip_next_T0[8])
82 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
83 _mm_prefetch(
84 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
85 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
86 // skip unnecessary prefetch of (&ip_next_T0[24])
87 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
88 _mm_prefetch(
89 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
90 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
91 // skip unnecessary prefetch of (&ip_next_T0[40])
92 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
93 _mm_prefetch(
94 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
95 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
96 // skip unnecessary prefetch of (&ip_next_T0[56])
97 vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
98 _mm_prefetch(
99 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
100 vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
101 // skip unnecessary prefetch of (&ip_next_T0[72])
102 vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
103 _mm_prefetch(
104 reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
105 vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
106 // skip unnecessary prefetch of (&ip_next_T0[88])
107 vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
108 _mm_prefetch(
109 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
110 vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
111 // skip unnecessary prefetch of (&ip_next_T0[104])
112 vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
113 _mm_prefetch(
114 reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
115 vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
116 // skip unnecessary prefetch of (&ip_next_T0[120])
117 }
118 if (!normalize_by_lengths || length == 0) {
119 _mm256_storeu_ps(&op[0], vop0);
120 _mm256_storeu_ps(&op[8], vop8);
121 _mm256_storeu_ps(&op[16], vop16);
122 _mm256_storeu_ps(&op[24], vop24);
123 _mm256_storeu_ps(&op[32], vop32);
124 _mm256_storeu_ps(&op[40], vop40);
125 _mm256_storeu_ps(&op[48], vop48);
126 _mm256_storeu_ps(&op[56], vop56);
127 _mm256_storeu_ps(&op[64], vop64);
128 _mm256_storeu_ps(&op[72], vop72);
129 _mm256_storeu_ps(&op[80], vop80);
130 _mm256_storeu_ps(&op[88], vop88);
131 _mm256_storeu_ps(&op[96], vop96);
132 _mm256_storeu_ps(&op[104], vop104);
133 _mm256_storeu_ps(&op[112], vop112);
134 _mm256_storeu_ps(&op[120], vop120);
135 } else {
136 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
137 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
138 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
139 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
140 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
141 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
142 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
143 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
144 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
145 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
146 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
147 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
148 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
149 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
150 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
151 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
152 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
153 }
154 }
155 } else if (block_size == 64) {
156 // unrolling 8 times
157 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
158 float* op = &out[rangeIndex * block_size];
159 __m256 vop0 = _mm256_setzero_ps();
160 __m256 vop8 = _mm256_setzero_ps();
161 __m256 vop16 = _mm256_setzero_ps();
162 __m256 vop24 = _mm256_setzero_ps();
163 __m256 vop32 = _mm256_setzero_ps();
164 __m256 vop40 = _mm256_setzero_ps();
165 __m256 vop48 = _mm256_setzero_ps();
166 __m256 vop56 = _mm256_setzero_ps();
167 if (dataInd != offsets[rangeIndex] - offsets[0]) {
168 return false;
169 }
170 int64_t end_offset = offsets[rangeIndex + 1];
171 int64_t length = end_offset - offsets[rangeIndex];
172 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
173 ++dataInd) {
174 const int idx = indices[dataInd];
175 if (idx < 0 || idx >= data_size) {
176 return false;
177 }
178 float wgt = 1.f;
179 if (weights) {
180 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
181 }
182 __m256 vwgt = _mm256_set1_ps(wgt);
183 const float* ip = &input[idx * fused_block_size];
184 const int next_T0 = (dataInd < index_size - prefdist_T0)
185 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
186 ? (dataInd + prefdist_T0)
187 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
188 : dataInd;
189 const int idx_pref_T0 = indices[next_T0];
190 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
191 return false;
192 }
193 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
194 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
195 _mm_prefetch(
196 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
197 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
198 // skip unnecessary prefetch of (&ip_next_T0[8])
199 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
200 _mm_prefetch(
201 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
202 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
203 // skip unnecessary prefetch of (&ip_next_T0[24])
204 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
205 _mm_prefetch(
206 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
207 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
208 // skip unnecessary prefetch of (&ip_next_T0[40])
209 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
210 _mm_prefetch(
211 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
212 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
213 // skip unnecessary prefetch of (&ip_next_T0[56])
214 }
215 if (!normalize_by_lengths || length == 0) {
216 _mm256_storeu_ps(&op[0], vop0);
217 _mm256_storeu_ps(&op[8], vop8);
218 _mm256_storeu_ps(&op[16], vop16);
219 _mm256_storeu_ps(&op[24], vop24);
220 _mm256_storeu_ps(&op[32], vop32);
221 _mm256_storeu_ps(&op[40], vop40);
222 _mm256_storeu_ps(&op[48], vop48);
223 _mm256_storeu_ps(&op[56], vop56);
224 } else {
225 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
226 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
227 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
228 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
229 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
230 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
231 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
232 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
233 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
234 }
235 }
236 } else if (block_size == 32) {
237 // unrolling 4 times
238 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
239 float* op = &out[rangeIndex * block_size];
240 __m256 vop0 = _mm256_setzero_ps();
241 __m256 vop8 = _mm256_setzero_ps();
242 __m256 vop16 = _mm256_setzero_ps();
243 __m256 vop24 = _mm256_setzero_ps();
244 if (dataInd != offsets[rangeIndex] - offsets[0]) {
245 return false;
246 }
247 int64_t end_offset = offsets[rangeIndex + 1];
248 int64_t length = end_offset - offsets[rangeIndex];
249 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
250 ++dataInd) {
251 const int idx = indices[dataInd];
252 if (idx < 0 || idx >= data_size) {
253 return false;
254 }
255 float wgt = 1.f;
256 if (weights) {
257 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
258 }
259 __m256 vwgt = _mm256_set1_ps(wgt);
260 const float* ip = &input[idx * fused_block_size];
261 const int next_T0 = (dataInd < index_size - prefdist_T0)
262 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
263 ? (dataInd + prefdist_T0)
264 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
265 : dataInd;
266 const int idx_pref_T0 = indices[next_T0];
267 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
268 return false;
269 }
270 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
271 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
272 _mm_prefetch(
273 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
274 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
275 // skip unnecessary prefetch of (&ip_next_T0[8])
276 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
277 _mm_prefetch(
278 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
279 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
280 // skip unnecessary prefetch of (&ip_next_T0[24])
281 }
282 if (!normalize_by_lengths || length == 0) {
283 _mm256_storeu_ps(&op[0], vop0);
284 _mm256_storeu_ps(&op[8], vop8);
285 _mm256_storeu_ps(&op[16], vop16);
286 _mm256_storeu_ps(&op[24], vop24);
287 } else {
288 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
289 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
290 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
291 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
292 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
293 }
294 }
295 } else if (block_size == 16) {
296 // unrolling 2 times
297 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
298 float* op = &out[rangeIndex * block_size];
299 __m256 vop0 = _mm256_setzero_ps();
300 __m256 vop8 = _mm256_setzero_ps();
301 if (dataInd != offsets[rangeIndex] - offsets[0]) {
302 return false;
303 }
304 int64_t end_offset = offsets[rangeIndex + 1];
305 int64_t length = end_offset - offsets[rangeIndex];
306 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
307 ++dataInd) {
308 const int idx = indices[dataInd];
309 if (idx < 0 || idx >= data_size) {
310 return false;
311 }
312 float wgt = 1.f;
313 if (weights) {
314 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
315 }
316 __m256 vwgt = _mm256_set1_ps(wgt);
317 const float* ip = &input[idx * fused_block_size];
318 const int next_T0 = (dataInd < index_size - prefdist_T0)
319 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
320 ? (dataInd + prefdist_T0)
321 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
322 : dataInd;
323 const int idx_pref_T0 = indices[next_T0];
324 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
325 return false;
326 }
327 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
328 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
329 _mm_prefetch(
330 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
331 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
332 // skip unnecessary prefetch of (&ip_next_T0[8])
333 }
334 if (!normalize_by_lengths || length == 0) {
335 _mm256_storeu_ps(&op[0], vop0);
336 _mm256_storeu_ps(&op[8], vop8);
337 } else {
338 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
339 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
340 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
341 }
342 }
343 } else {
344 // generic code
345 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
346 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
347 float* op = &out[rangeIndex * block_size];
348 int64_t j = 0;
349 for (; j + 8 <= block_size; j += 8) {
350 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
351 }
352 for (; j < block_size; j++) {
353 op[j] = 0.0f;
354 }
355 if (dataInd != offsets[rangeIndex] - offsets[0]) {
356 return false;
357 }
358 int64_t end_offset = offsets[rangeIndex + 1];
359 int64_t length = end_offset - offsets[rangeIndex];
360 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
361 ++dataInd) {
362 const int idx = indices[dataInd];
363 if (idx < 0 || idx >= data_size) {
364 return false;
365 }
366 float wgt = 1.f;
367 if (weights) {
368 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
369 }
370 __m256 vwgt = _mm256_set1_ps(wgt);
371 const float* ip = &input[idx * fused_block_size];
372 const int next_T0 = (dataInd < index_size - prefdist_T0)
373 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
374 ? (dataInd + prefdist_T0)
375 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
376 : dataInd;
377 const int idx_pref_T0 = indices[next_T0];
378 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
379 return false;
380 }
381 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
382 j = 0;
383 for (; j + 8 <= block_size; j += 8) {
384 _mm256_storeu_ps(
385 &op[j],
386 _mm256_fmadd_ps(
387 vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
388 _mm_prefetch(
389 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
390 }
391 for (; j < block_size; j++) {
392 op[j] = std::fma(wgt, ip[j], op[j]);
393 }
394 }
395 if (normalize_by_lengths && length) {
396 float len_inv = 1.0f / length;
397 __m256 vlen_inv = _mm256_set1_ps(len_inv);
398 j = 0;
399 for (; j + 8 <= block_size; j += 8) {
400 _mm256_storeu_ps(
401 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
402 }
403 for (; j < block_size; j++) {
404 op[j] = len_inv * op[j];
405 }
406 }
407 }
408 }
409 return dataInd == index_size;
410 }
EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const float * input,const int * indices,const int * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,float * out)411 bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
412 const int64_t block_size,
413 const int64_t output_size,
414 const int64_t index_size,
415 const int64_t data_size,
416 const float* input,
417 const int* indices,
418 const int* offsets,
419 const float* weights,
420 const float* scale_bias,
421 bool normalize_by_lengths,
422 float* out) {
423 return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<false>(
424 block_size,
425 output_size,
426 index_size,
427 data_size,
428 input,
429 indices,
430 offsets,
431 weights,
432 scale_bias,
433 normalize_by_lengths,
434 out);
435 }
EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const float * input,const int * indices,const int * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,float * out)436 bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
437 const int64_t block_size,
438 const int64_t output_size,
439 const int64_t index_size,
440 const int64_t data_size,
441 const float* input,
442 const int* indices,
443 const int* offsets,
444 const float* weights,
445 const float* scale_bias,
446 bool normalize_by_lengths,
447 float* out) {
448 return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<true>(
449 block_size,
450 output_size,
451 index_size,
452 data_size,
453 input,
454 indices,
455 offsets,
456 weights,
457 scale_bias,
458 normalize_by_lengths,
459 out);
460 }
461
462 template <bool IS_WEIGHT_POSITIONAL>
EmbeddingLookupIdx_int64_t_float_float__avx2_fma(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const float * input,const int64_t * indices,const int64_t * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,float * out)463 static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma(
464 const int64_t block_size,
465 const int64_t output_size,
466 const int64_t index_size,
467 const int64_t data_size,
468 const float* input,
469 const int64_t* indices,
470 const int64_t* offsets,
471 const float* weights,
472 const float* scale_bias,
473 bool normalize_by_lengths,
474 float* out) {
475 const int64_t prefdist_T0 = 16;
476 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
477 const int64_t fused_block_size = block_size + 0;
478 int64_t dataInd = 0;
479 if (block_size == 128) {
480 // unrolling 16 times
481 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
482 float* op = &out[rangeIndex * block_size];
483 __m256 vop0 = _mm256_setzero_ps();
484 __m256 vop8 = _mm256_setzero_ps();
485 __m256 vop16 = _mm256_setzero_ps();
486 __m256 vop24 = _mm256_setzero_ps();
487 __m256 vop32 = _mm256_setzero_ps();
488 __m256 vop40 = _mm256_setzero_ps();
489 __m256 vop48 = _mm256_setzero_ps();
490 __m256 vop56 = _mm256_setzero_ps();
491 __m256 vop64 = _mm256_setzero_ps();
492 __m256 vop72 = _mm256_setzero_ps();
493 __m256 vop80 = _mm256_setzero_ps();
494 __m256 vop88 = _mm256_setzero_ps();
495 __m256 vop96 = _mm256_setzero_ps();
496 __m256 vop104 = _mm256_setzero_ps();
497 __m256 vop112 = _mm256_setzero_ps();
498 __m256 vop120 = _mm256_setzero_ps();
499 if (dataInd != offsets[rangeIndex] - offsets[0]) {
500 return false;
501 }
502 int64_t end_offset = offsets[rangeIndex + 1];
503 int64_t length = end_offset - offsets[rangeIndex];
504 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
505 ++dataInd) {
506 const int64_t idx = indices[dataInd];
507 if (idx < 0 || idx >= data_size) {
508 return false;
509 }
510 float wgt = 1.f;
511 if (weights) {
512 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
513 }
514 __m256 vwgt = _mm256_set1_ps(wgt);
515 const float* ip = &input[idx * fused_block_size];
516 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
517 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
518 ? (dataInd + prefdist_T0)
519 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
520 : dataInd;
521 const int64_t idx_pref_T0 = indices[next_T0];
522 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
523 return false;
524 }
525 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
526 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
527 _mm_prefetch(
528 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
529 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
530 // skip unnecessary prefetch of (&ip_next_T0[8])
531 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
532 _mm_prefetch(
533 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
534 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
535 // skip unnecessary prefetch of (&ip_next_T0[24])
536 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
537 _mm_prefetch(
538 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
539 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
540 // skip unnecessary prefetch of (&ip_next_T0[40])
541 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
542 _mm_prefetch(
543 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
544 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
545 // skip unnecessary prefetch of (&ip_next_T0[56])
546 vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
547 _mm_prefetch(
548 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
549 vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
550 // skip unnecessary prefetch of (&ip_next_T0[72])
551 vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
552 _mm_prefetch(
553 reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
554 vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
555 // skip unnecessary prefetch of (&ip_next_T0[88])
556 vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
557 _mm_prefetch(
558 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
559 vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
560 // skip unnecessary prefetch of (&ip_next_T0[104])
561 vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
562 _mm_prefetch(
563 reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
564 vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
565 // skip unnecessary prefetch of (&ip_next_T0[120])
566 }
567 if (!normalize_by_lengths || length == 0) {
568 _mm256_storeu_ps(&op[0], vop0);
569 _mm256_storeu_ps(&op[8], vop8);
570 _mm256_storeu_ps(&op[16], vop16);
571 _mm256_storeu_ps(&op[24], vop24);
572 _mm256_storeu_ps(&op[32], vop32);
573 _mm256_storeu_ps(&op[40], vop40);
574 _mm256_storeu_ps(&op[48], vop48);
575 _mm256_storeu_ps(&op[56], vop56);
576 _mm256_storeu_ps(&op[64], vop64);
577 _mm256_storeu_ps(&op[72], vop72);
578 _mm256_storeu_ps(&op[80], vop80);
579 _mm256_storeu_ps(&op[88], vop88);
580 _mm256_storeu_ps(&op[96], vop96);
581 _mm256_storeu_ps(&op[104], vop104);
582 _mm256_storeu_ps(&op[112], vop112);
583 _mm256_storeu_ps(&op[120], vop120);
584 } else {
585 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
586 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
587 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
588 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
589 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
590 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
591 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
592 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
593 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
594 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
595 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
596 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
597 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
598 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
599 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
600 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
601 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
602 }
603 }
604 } else if (block_size == 64) {
605 // unrolling 8 times
606 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
607 float* op = &out[rangeIndex * block_size];
608 __m256 vop0 = _mm256_setzero_ps();
609 __m256 vop8 = _mm256_setzero_ps();
610 __m256 vop16 = _mm256_setzero_ps();
611 __m256 vop24 = _mm256_setzero_ps();
612 __m256 vop32 = _mm256_setzero_ps();
613 __m256 vop40 = _mm256_setzero_ps();
614 __m256 vop48 = _mm256_setzero_ps();
615 __m256 vop56 = _mm256_setzero_ps();
616 if (dataInd != offsets[rangeIndex] - offsets[0]) {
617 return false;
618 }
619 int64_t end_offset = offsets[rangeIndex + 1];
620 int64_t length = end_offset - offsets[rangeIndex];
621 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
622 ++dataInd) {
623 const int64_t idx = indices[dataInd];
624 if (idx < 0 || idx >= data_size) {
625 return false;
626 }
627 float wgt = 1.f;
628 if (weights) {
629 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
630 }
631 __m256 vwgt = _mm256_set1_ps(wgt);
632 const float* ip = &input[idx * fused_block_size];
633 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
634 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
635 ? (dataInd + prefdist_T0)
636 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
637 : dataInd;
638 const int64_t idx_pref_T0 = indices[next_T0];
639 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
640 return false;
641 }
642 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
643 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
644 _mm_prefetch(
645 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
646 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
647 // skip unnecessary prefetch of (&ip_next_T0[8])
648 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
649 _mm_prefetch(
650 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
651 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
652 // skip unnecessary prefetch of (&ip_next_T0[24])
653 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
654 _mm_prefetch(
655 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
656 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
657 // skip unnecessary prefetch of (&ip_next_T0[40])
658 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
659 _mm_prefetch(
660 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
661 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
662 // skip unnecessary prefetch of (&ip_next_T0[56])
663 }
664 if (!normalize_by_lengths || length == 0) {
665 _mm256_storeu_ps(&op[0], vop0);
666 _mm256_storeu_ps(&op[8], vop8);
667 _mm256_storeu_ps(&op[16], vop16);
668 _mm256_storeu_ps(&op[24], vop24);
669 _mm256_storeu_ps(&op[32], vop32);
670 _mm256_storeu_ps(&op[40], vop40);
671 _mm256_storeu_ps(&op[48], vop48);
672 _mm256_storeu_ps(&op[56], vop56);
673 } else {
674 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
675 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
676 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
677 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
678 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
679 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
680 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
681 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
682 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
683 }
684 }
685 } else if (block_size == 32) {
686 // unrolling 4 times
687 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
688 float* op = &out[rangeIndex * block_size];
689 __m256 vop0 = _mm256_setzero_ps();
690 __m256 vop8 = _mm256_setzero_ps();
691 __m256 vop16 = _mm256_setzero_ps();
692 __m256 vop24 = _mm256_setzero_ps();
693 if (dataInd != offsets[rangeIndex] - offsets[0]) {
694 return false;
695 }
696 int64_t end_offset = offsets[rangeIndex + 1];
697 int64_t length = end_offset - offsets[rangeIndex];
698 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
699 ++dataInd) {
700 const int64_t idx = indices[dataInd];
701 if (idx < 0 || idx >= data_size) {
702 return false;
703 }
704 float wgt = 1.f;
705 if (weights) {
706 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
707 }
708 __m256 vwgt = _mm256_set1_ps(wgt);
709 const float* ip = &input[idx * fused_block_size];
710 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
711 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
712 ? (dataInd + prefdist_T0)
713 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
714 : dataInd;
715 const int64_t idx_pref_T0 = indices[next_T0];
716 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
717 return false;
718 }
719 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
720 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
721 _mm_prefetch(
722 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
723 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
724 // skip unnecessary prefetch of (&ip_next_T0[8])
725 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
726 _mm_prefetch(
727 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
728 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
729 // skip unnecessary prefetch of (&ip_next_T0[24])
730 }
731 if (!normalize_by_lengths || length == 0) {
732 _mm256_storeu_ps(&op[0], vop0);
733 _mm256_storeu_ps(&op[8], vop8);
734 _mm256_storeu_ps(&op[16], vop16);
735 _mm256_storeu_ps(&op[24], vop24);
736 } else {
737 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
738 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
739 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
740 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
741 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
742 }
743 }
744 } else if (block_size == 16) {
745 // unrolling 2 times
746 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
747 float* op = &out[rangeIndex * block_size];
748 __m256 vop0 = _mm256_setzero_ps();
749 __m256 vop8 = _mm256_setzero_ps();
750 if (dataInd != offsets[rangeIndex] - offsets[0]) {
751 return false;
752 }
753 int64_t end_offset = offsets[rangeIndex + 1];
754 int64_t length = end_offset - offsets[rangeIndex];
755 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
756 ++dataInd) {
757 const int64_t idx = indices[dataInd];
758 if (idx < 0 || idx >= data_size) {
759 return false;
760 }
761 float wgt = 1.f;
762 if (weights) {
763 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
764 }
765 __m256 vwgt = _mm256_set1_ps(wgt);
766 const float* ip = &input[idx * fused_block_size];
767 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
768 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
769 ? (dataInd + prefdist_T0)
770 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
771 : dataInd;
772 const int64_t idx_pref_T0 = indices[next_T0];
773 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
774 return false;
775 }
776 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
777 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
778 _mm_prefetch(
779 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
780 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
781 // skip unnecessary prefetch of (&ip_next_T0[8])
782 }
783 if (!normalize_by_lengths || length == 0) {
784 _mm256_storeu_ps(&op[0], vop0);
785 _mm256_storeu_ps(&op[8], vop8);
786 } else {
787 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
788 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
789 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
790 }
791 }
792 } else {
793 // generic code
794 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
795 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
796 float* op = &out[rangeIndex * block_size];
797 int64_t j = 0;
798 for (; j + 8 <= block_size; j += 8) {
799 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
800 }
801 for (; j < block_size; j++) {
802 op[j] = 0.0f;
803 }
804 if (dataInd != offsets[rangeIndex] - offsets[0]) {
805 return false;
806 }
807 int64_t end_offset = offsets[rangeIndex + 1];
808 int64_t length = end_offset - offsets[rangeIndex];
809 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
810 ++dataInd) {
811 const int64_t idx = indices[dataInd];
812 if (idx < 0 || idx >= data_size) {
813 return false;
814 }
815 float wgt = 1.f;
816 if (weights) {
817 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
818 }
819 __m256 vwgt = _mm256_set1_ps(wgt);
820 const float* ip = &input[idx * fused_block_size];
821 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
822 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
823 ? (dataInd + prefdist_T0)
824 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
825 : dataInd;
826 const int64_t idx_pref_T0 = indices[next_T0];
827 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
828 return false;
829 }
830 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
831 j = 0;
832 for (; j + 8 <= block_size; j += 8) {
833 _mm256_storeu_ps(
834 &op[j],
835 _mm256_fmadd_ps(
836 vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
837 _mm_prefetch(
838 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
839 }
840 for (; j < block_size; j++) {
841 op[j] = std::fma(wgt, ip[j], op[j]);
842 }
843 }
844 if (normalize_by_lengths && length) {
845 float len_inv = 1.0f / length;
846 __m256 vlen_inv = _mm256_set1_ps(len_inv);
847 j = 0;
848 for (; j + 8 <= block_size; j += 8) {
849 _mm256_storeu_ps(
850 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
851 }
852 for (; j < block_size; j++) {
853 op[j] = len_inv * op[j];
854 }
855 }
856 }
857 }
858 return dataInd == index_size;
859 }
EmbeddingLookupIdx_int64_t_float_float_false__avx2_fma(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const float * input,const int64_t * indices,const int64_t * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,float * out)860 bool EmbeddingLookupIdx_int64_t_float_float_false__avx2_fma(
861 const int64_t block_size,
862 const int64_t output_size,
863 const int64_t index_size,
864 const int64_t data_size,
865 const float* input,
866 const int64_t* indices,
867 const int64_t* offsets,
868 const float* weights,
869 const float* scale_bias,
870 bool normalize_by_lengths,
871 float* out) {
872 return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<false>(
873 block_size,
874 output_size,
875 index_size,
876 data_size,
877 input,
878 indices,
879 offsets,
880 weights,
881 scale_bias,
882 normalize_by_lengths,
883 out);
884 }
EmbeddingLookupIdx_int64_t_float_float_true__avx2_fma(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const float * input,const int64_t * indices,const int64_t * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,float * out)885 bool EmbeddingLookupIdx_int64_t_float_float_true__avx2_fma(
886 const int64_t block_size,
887 const int64_t output_size,
888 const int64_t index_size,
889 const int64_t data_size,
890 const float* input,
891 const int64_t* indices,
892 const int64_t* offsets,
893 const float* weights,
894 const float* scale_bias,
895 bool normalize_by_lengths,
896 float* out) {
897 return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<true>(
898 block_size,
899 output_size,
900 index_size,
901 data_size,
902 input,
903 indices,
904 offsets,
905 weights,
906 scale_bias,
907 normalize_by_lengths,
908 out);
909 }
910
911 template <bool IS_WEIGHT_POSITIONAL>
EmbeddingLookupIdx_int32_t_half_float__avx2_fma(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const at::Half * input,const int * indices,const int * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,float * out)912 static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma(
913 const int64_t block_size,
914 const int64_t output_size,
915 const int64_t index_size,
916 const int64_t data_size,
917 const at::Half* input,
918 const int* indices,
919 const int* offsets,
920 const float* weights,
921 const float* scale_bias,
922 bool normalize_by_lengths,
923 float* out) {
924 const int prefdist_T0 = 16;
925 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
926 const int fused_block_size = block_size + 0;
927 int64_t dataInd = 0;
928 if (block_size == 128) {
929 // unrolling 16 times
930 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
931 float* op = &out[rangeIndex * block_size];
932 __m256 vop0 = _mm256_setzero_ps();
933 __m256 vop8 = _mm256_setzero_ps();
934 __m256 vop16 = _mm256_setzero_ps();
935 __m256 vop24 = _mm256_setzero_ps();
936 __m256 vop32 = _mm256_setzero_ps();
937 __m256 vop40 = _mm256_setzero_ps();
938 __m256 vop48 = _mm256_setzero_ps();
939 __m256 vop56 = _mm256_setzero_ps();
940 __m256 vop64 = _mm256_setzero_ps();
941 __m256 vop72 = _mm256_setzero_ps();
942 __m256 vop80 = _mm256_setzero_ps();
943 __m256 vop88 = _mm256_setzero_ps();
944 __m256 vop96 = _mm256_setzero_ps();
945 __m256 vop104 = _mm256_setzero_ps();
946 __m256 vop112 = _mm256_setzero_ps();
947 __m256 vop120 = _mm256_setzero_ps();
948 if (dataInd != offsets[rangeIndex] - offsets[0]) {
949 return false;
950 }
951 int64_t end_offset = offsets[rangeIndex + 1];
952 int64_t length = end_offset - offsets[rangeIndex];
953 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
954 ++dataInd) {
955 const int idx = indices[dataInd];
956 if (idx < 0 || idx >= data_size) {
957 return false;
958 }
959 float wgt = 1.f;
960 if (weights) {
961 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
962 }
963 __m256 vwgt = _mm256_set1_ps(wgt);
964 const at::Half* ip = &input[idx * fused_block_size];
965 const int next_T0 = (dataInd < index_size - prefdist_T0)
966 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
967 ? (dataInd + prefdist_T0)
968 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
969 : dataInd;
970 const int idx_pref_T0 = indices[next_T0];
971 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
972 return false;
973 }
974 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
975 vop0 = _mm256_fmadd_ps(
976 vwgt,
977 _mm256_cvtph_ps(
978 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
979 vop0);
980 _mm_prefetch(
981 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
982 vop8 = _mm256_fmadd_ps(
983 vwgt,
984 _mm256_cvtph_ps(
985 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
986 vop8);
987 // skip unnecessary prefetch of (&ip_next_T0[8])
988 vop16 = _mm256_fmadd_ps(
989 vwgt,
990 _mm256_cvtph_ps(
991 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
992 vop16);
993 // skip unnecessary prefetch of (&ip_next_T0[16])
994 vop24 = _mm256_fmadd_ps(
995 vwgt,
996 _mm256_cvtph_ps(
997 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
998 vop24);
999 // skip unnecessary prefetch of (&ip_next_T0[24])
1000 vop32 = _mm256_fmadd_ps(
1001 vwgt,
1002 _mm256_cvtph_ps(
1003 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1004 vop32);
1005 _mm_prefetch(
1006 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1007 vop40 = _mm256_fmadd_ps(
1008 vwgt,
1009 _mm256_cvtph_ps(
1010 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1011 vop40);
1012 // skip unnecessary prefetch of (&ip_next_T0[40])
1013 vop48 = _mm256_fmadd_ps(
1014 vwgt,
1015 _mm256_cvtph_ps(
1016 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1017 vop48);
1018 // skip unnecessary prefetch of (&ip_next_T0[48])
1019 vop56 = _mm256_fmadd_ps(
1020 vwgt,
1021 _mm256_cvtph_ps(
1022 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1023 vop56);
1024 // skip unnecessary prefetch of (&ip_next_T0[56])
1025 vop64 = _mm256_fmadd_ps(
1026 vwgt,
1027 _mm256_cvtph_ps(
1028 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1029 vop64);
1030 _mm_prefetch(
1031 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1032 vop72 = _mm256_fmadd_ps(
1033 vwgt,
1034 _mm256_cvtph_ps(
1035 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1036 vop72);
1037 // skip unnecessary prefetch of (&ip_next_T0[72])
1038 vop80 = _mm256_fmadd_ps(
1039 vwgt,
1040 _mm256_cvtph_ps(
1041 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1042 vop80);
1043 // skip unnecessary prefetch of (&ip_next_T0[80])
1044 vop88 = _mm256_fmadd_ps(
1045 vwgt,
1046 _mm256_cvtph_ps(
1047 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1048 vop88);
1049 // skip unnecessary prefetch of (&ip_next_T0[88])
1050 vop96 = _mm256_fmadd_ps(
1051 vwgt,
1052 _mm256_cvtph_ps(
1053 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1054 vop96);
1055 _mm_prefetch(
1056 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1057 vop104 = _mm256_fmadd_ps(
1058 vwgt,
1059 _mm256_cvtph_ps(
1060 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1061 vop104);
1062 // skip unnecessary prefetch of (&ip_next_T0[104])
1063 vop112 = _mm256_fmadd_ps(
1064 vwgt,
1065 _mm256_cvtph_ps(
1066 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1067 vop112);
1068 // skip unnecessary prefetch of (&ip_next_T0[112])
1069 vop120 = _mm256_fmadd_ps(
1070 vwgt,
1071 _mm256_cvtph_ps(
1072 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1073 vop120);
1074 // skip unnecessary prefetch of (&ip_next_T0[120])
1075 }
1076 if (!normalize_by_lengths || length == 0) {
1077 _mm256_storeu_ps(&op[0], vop0);
1078 _mm256_storeu_ps(&op[8], vop8);
1079 _mm256_storeu_ps(&op[16], vop16);
1080 _mm256_storeu_ps(&op[24], vop24);
1081 _mm256_storeu_ps(&op[32], vop32);
1082 _mm256_storeu_ps(&op[40], vop40);
1083 _mm256_storeu_ps(&op[48], vop48);
1084 _mm256_storeu_ps(&op[56], vop56);
1085 _mm256_storeu_ps(&op[64], vop64);
1086 _mm256_storeu_ps(&op[72], vop72);
1087 _mm256_storeu_ps(&op[80], vop80);
1088 _mm256_storeu_ps(&op[88], vop88);
1089 _mm256_storeu_ps(&op[96], vop96);
1090 _mm256_storeu_ps(&op[104], vop104);
1091 _mm256_storeu_ps(&op[112], vop112);
1092 _mm256_storeu_ps(&op[120], vop120);
1093 } else {
1094 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1095 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1096 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1097 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1098 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1099 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1100 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1101 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1102 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1103 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1104 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1105 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1106 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1107 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1108 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1109 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1110 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1111 }
1112 }
1113 } else if (block_size == 64) {
1114 // unrolling 8 times
1115 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1116 float* op = &out[rangeIndex * block_size];
1117 __m256 vop0 = _mm256_setzero_ps();
1118 __m256 vop8 = _mm256_setzero_ps();
1119 __m256 vop16 = _mm256_setzero_ps();
1120 __m256 vop24 = _mm256_setzero_ps();
1121 __m256 vop32 = _mm256_setzero_ps();
1122 __m256 vop40 = _mm256_setzero_ps();
1123 __m256 vop48 = _mm256_setzero_ps();
1124 __m256 vop56 = _mm256_setzero_ps();
1125 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1126 return false;
1127 }
1128 int64_t end_offset = offsets[rangeIndex + 1];
1129 int64_t length = end_offset - offsets[rangeIndex];
1130 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1131 ++dataInd) {
1132 const int idx = indices[dataInd];
1133 if (idx < 0 || idx >= data_size) {
1134 return false;
1135 }
1136 float wgt = 1.f;
1137 if (weights) {
1138 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1139 }
1140 __m256 vwgt = _mm256_set1_ps(wgt);
1141 const at::Half* ip = &input[idx * fused_block_size];
1142 const int next_T0 = (dataInd < index_size - prefdist_T0)
1143 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1144 ? (dataInd + prefdist_T0)
1145 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1146 : dataInd;
1147 const int idx_pref_T0 = indices[next_T0];
1148 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1149 return false;
1150 }
1151 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1152 vop0 = _mm256_fmadd_ps(
1153 vwgt,
1154 _mm256_cvtph_ps(
1155 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1156 vop0);
1157 _mm_prefetch(
1158 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1159 vop8 = _mm256_fmadd_ps(
1160 vwgt,
1161 _mm256_cvtph_ps(
1162 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1163 vop8);
1164 // skip unnecessary prefetch of (&ip_next_T0[8])
1165 vop16 = _mm256_fmadd_ps(
1166 vwgt,
1167 _mm256_cvtph_ps(
1168 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1169 vop16);
1170 // skip unnecessary prefetch of (&ip_next_T0[16])
1171 vop24 = _mm256_fmadd_ps(
1172 vwgt,
1173 _mm256_cvtph_ps(
1174 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1175 vop24);
1176 // skip unnecessary prefetch of (&ip_next_T0[24])
1177 vop32 = _mm256_fmadd_ps(
1178 vwgt,
1179 _mm256_cvtph_ps(
1180 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1181 vop32);
1182 _mm_prefetch(
1183 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1184 vop40 = _mm256_fmadd_ps(
1185 vwgt,
1186 _mm256_cvtph_ps(
1187 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1188 vop40);
1189 // skip unnecessary prefetch of (&ip_next_T0[40])
1190 vop48 = _mm256_fmadd_ps(
1191 vwgt,
1192 _mm256_cvtph_ps(
1193 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1194 vop48);
1195 // skip unnecessary prefetch of (&ip_next_T0[48])
1196 vop56 = _mm256_fmadd_ps(
1197 vwgt,
1198 _mm256_cvtph_ps(
1199 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1200 vop56);
1201 // skip unnecessary prefetch of (&ip_next_T0[56])
1202 }
1203 if (!normalize_by_lengths || length == 0) {
1204 _mm256_storeu_ps(&op[0], vop0);
1205 _mm256_storeu_ps(&op[8], vop8);
1206 _mm256_storeu_ps(&op[16], vop16);
1207 _mm256_storeu_ps(&op[24], vop24);
1208 _mm256_storeu_ps(&op[32], vop32);
1209 _mm256_storeu_ps(&op[40], vop40);
1210 _mm256_storeu_ps(&op[48], vop48);
1211 _mm256_storeu_ps(&op[56], vop56);
1212 } else {
1213 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1214 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1215 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1216 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1217 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1218 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1219 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1220 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1221 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1222 }
1223 }
1224 } else if (block_size == 32) {
1225 // unrolling 4 times
1226 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1227 float* op = &out[rangeIndex * block_size];
1228 __m256 vop0 = _mm256_setzero_ps();
1229 __m256 vop8 = _mm256_setzero_ps();
1230 __m256 vop16 = _mm256_setzero_ps();
1231 __m256 vop24 = _mm256_setzero_ps();
1232 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1233 return false;
1234 }
1235 int64_t end_offset = offsets[rangeIndex + 1];
1236 int64_t length = end_offset - offsets[rangeIndex];
1237 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1238 ++dataInd) {
1239 const int idx = indices[dataInd];
1240 if (idx < 0 || idx >= data_size) {
1241 return false;
1242 }
1243 float wgt = 1.f;
1244 if (weights) {
1245 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1246 }
1247 __m256 vwgt = _mm256_set1_ps(wgt);
1248 const at::Half* ip = &input[idx * fused_block_size];
1249 const int next_T0 = (dataInd < index_size - prefdist_T0)
1250 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1251 ? (dataInd + prefdist_T0)
1252 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1253 : dataInd;
1254 const int idx_pref_T0 = indices[next_T0];
1255 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1256 return false;
1257 }
1258 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1259 vop0 = _mm256_fmadd_ps(
1260 vwgt,
1261 _mm256_cvtph_ps(
1262 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1263 vop0);
1264 _mm_prefetch(
1265 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1266 vop8 = _mm256_fmadd_ps(
1267 vwgt,
1268 _mm256_cvtph_ps(
1269 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1270 vop8);
1271 // skip unnecessary prefetch of (&ip_next_T0[8])
1272 vop16 = _mm256_fmadd_ps(
1273 vwgt,
1274 _mm256_cvtph_ps(
1275 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1276 vop16);
1277 // skip unnecessary prefetch of (&ip_next_T0[16])
1278 vop24 = _mm256_fmadd_ps(
1279 vwgt,
1280 _mm256_cvtph_ps(
1281 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1282 vop24);
1283 // skip unnecessary prefetch of (&ip_next_T0[24])
1284 }
1285 if (!normalize_by_lengths || length == 0) {
1286 _mm256_storeu_ps(&op[0], vop0);
1287 _mm256_storeu_ps(&op[8], vop8);
1288 _mm256_storeu_ps(&op[16], vop16);
1289 _mm256_storeu_ps(&op[24], vop24);
1290 } else {
1291 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1292 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1293 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1294 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1295 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1296 }
1297 }
1298 } else if (block_size == 16) {
1299 // unrolling 2 times
1300 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1301 float* op = &out[rangeIndex * block_size];
1302 __m256 vop0 = _mm256_setzero_ps();
1303 __m256 vop8 = _mm256_setzero_ps();
1304 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1305 return false;
1306 }
1307 int64_t end_offset = offsets[rangeIndex + 1];
1308 int64_t length = end_offset - offsets[rangeIndex];
1309 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1310 ++dataInd) {
1311 const int idx = indices[dataInd];
1312 if (idx < 0 || idx >= data_size) {
1313 return false;
1314 }
1315 float wgt = 1.f;
1316 if (weights) {
1317 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1318 }
1319 __m256 vwgt = _mm256_set1_ps(wgt);
1320 const at::Half* ip = &input[idx * fused_block_size];
1321 const int next_T0 = (dataInd < index_size - prefdist_T0)
1322 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1323 ? (dataInd + prefdist_T0)
1324 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1325 : dataInd;
1326 const int idx_pref_T0 = indices[next_T0];
1327 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1328 return false;
1329 }
1330 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1331 vop0 = _mm256_fmadd_ps(
1332 vwgt,
1333 _mm256_cvtph_ps(
1334 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1335 vop0);
1336 _mm_prefetch(
1337 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1338 vop8 = _mm256_fmadd_ps(
1339 vwgt,
1340 _mm256_cvtph_ps(
1341 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1342 vop8);
1343 // skip unnecessary prefetch of (&ip_next_T0[8])
1344 }
1345 if (!normalize_by_lengths || length == 0) {
1346 _mm256_storeu_ps(&op[0], vop0);
1347 _mm256_storeu_ps(&op[8], vop8);
1348 } else {
1349 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1350 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1351 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1352 }
1353 }
1354 } else {
1355 // generic code
1356 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
1357 alignas(64) at::Half vtmp1[8] = {0};
1358 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1359 float* op = &out[rangeIndex * block_size];
1360 int64_t j = 0;
1361 for (; j + 8 <= block_size; j += 8) {
1362 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1363 }
1364 for (; j < block_size; j++) {
1365 op[j] = 0.0f;
1366 }
1367 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1368 return false;
1369 }
1370 int64_t end_offset = offsets[rangeIndex + 1];
1371 int64_t length = end_offset - offsets[rangeIndex];
1372 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1373 ++dataInd) {
1374 const int idx = indices[dataInd];
1375 if (idx < 0 || idx >= data_size) {
1376 return false;
1377 }
1378 float wgt = 1.f;
1379 if (weights) {
1380 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1381 }
1382 __m256 vwgt = _mm256_set1_ps(wgt);
1383 const at::Half* ip = &input[idx * fused_block_size];
1384 const int next_T0 = (dataInd < index_size - prefdist_T0)
1385 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1386 ? (dataInd + prefdist_T0)
1387 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1388 : dataInd;
1389 const int idx_pref_T0 = indices[next_T0];
1390 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1391 return false;
1392 }
1393 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1394 j = 0;
1395 for (; j + 8 <= block_size; j += 8) {
1396 _mm256_storeu_ps(
1397 &op[j],
1398 _mm256_fmadd_ps(
1399 vwgt,
1400 _mm256_cvtph_ps(_mm_loadu_si128(
1401 reinterpret_cast<const __m128i*>(&ip[j]))),
1402 _mm256_loadu_ps(&op[j])));
1403 _mm_prefetch(
1404 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1405 }
1406 for (; j < block_size; j++) {
1407 vtmp1[0] = ip[j];
1408 __m256 vtmp2 =
1409 _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
1410 op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
1411 }
1412 }
1413 if (normalize_by_lengths && length) {
1414 float len_inv = 1.0f / length;
1415 __m256 vlen_inv = _mm256_set1_ps(len_inv);
1416 j = 0;
1417 for (; j + 8 <= block_size; j += 8) {
1418 _mm256_storeu_ps(
1419 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1420 }
1421 for (; j < block_size; j++) {
1422 op[j] = len_inv * op[j];
1423 }
1424 }
1425 }
1426 }
1427 return dataInd == index_size;
1428 }
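// The non-template entry points below just instantiate the kernel with
// IS_WEIGHT_POSITIONAL fixed to false or true. When it is true, weights are
// indexed by the position of an index within its segment (dataInd - start)
// rather than by the global position in indices[].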
1429 bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
1430 const int64_t block_size,
1431 const int64_t output_size,
1432 const int64_t index_size,
1433 const int64_t data_size,
1434 const at::Half* input,
1435 const int* indices,
1436 const int* offsets,
1437 const float* weights,
1438 const float* scale_bias,
1439 bool normalize_by_lengths,
1440 float* out) {
1441 return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<false>(
1442 block_size,
1443 output_size,
1444 index_size,
1445 data_size,
1446 input,
1447 indices,
1448 offsets,
1449 weights,
1450 scale_bias,
1451 normalize_by_lengths,
1452 out);
1453 }
1454 bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
1455 const int64_t block_size,
1456 const int64_t output_size,
1457 const int64_t index_size,
1458 const int64_t data_size,
1459 const at::Half* input,
1460 const int* indices,
1461 const int* offsets,
1462 const float* weights,
1463 const float* scale_bias,
1464 bool normalize_by_lengths,
1465 float* out) {
1466 return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<true>(
1467 block_size,
1468 output_size,
1469 index_size,
1470 data_size,
1471 input,
1472 indices,
1473 offsets,
1474 weights,
1475 scale_bias,
1476 normalize_by_lengths,
1477 out);
1478 }
1479
1480 template <bool IS_WEIGHT_POSITIONAL>
1481 static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma(
1482 const int64_t block_size,
1483 const int64_t output_size,
1484 const int64_t index_size,
1485 const int64_t data_size,
1486 const at::Half* input,
1487 const int64_t* indices,
1488 const int64_t* offsets,
1489 const float* weights,
1490 const float* scale_bias,
1491 bool normalize_by_lengths,
1492 float* out) {
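  // offsets[] holds CSR-style segment boundaries into indices[]: segment r
  // covers positions [offsets[r] - offsets[0], offsets[r + 1] - offsets[0]).
  // For each output row r this accumulates
  //   out[r][:] += wgt * input[indices[i]][:]
  // over the segment (optionally scaled by 1/length afterwards), widening the
  // fp16 rows to fp32 with _mm256_cvtph_ps before the FMA. block_size values
  // of 128/64/32/16 take fully unrolled paths; anything else falls through to
  // the generic loop at the bottom. Out-of-range indices or inconsistent
  // offsets make the kernel return false so the caller can handle the error.
  // The row referenced prefdist_T0 iterations ahead is prefetched into L1,
  // roughly one prefetch per 64-byte cache line (32 fp16 elements), which is
  // why some offsets are skipped as "unnecessary prefetch".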
1493 const int64_t prefdist_T0 = 16;
1494 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1495 const int64_t fused_block_size = block_size + 0;
1496 int64_t dataInd = 0;
1497 if (block_size == 128) {
1498 // unrolling 16 times
1499 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1500 float* op = &out[rangeIndex * block_size];
1501 __m256 vop0 = _mm256_setzero_ps();
1502 __m256 vop8 = _mm256_setzero_ps();
1503 __m256 vop16 = _mm256_setzero_ps();
1504 __m256 vop24 = _mm256_setzero_ps();
1505 __m256 vop32 = _mm256_setzero_ps();
1506 __m256 vop40 = _mm256_setzero_ps();
1507 __m256 vop48 = _mm256_setzero_ps();
1508 __m256 vop56 = _mm256_setzero_ps();
1509 __m256 vop64 = _mm256_setzero_ps();
1510 __m256 vop72 = _mm256_setzero_ps();
1511 __m256 vop80 = _mm256_setzero_ps();
1512 __m256 vop88 = _mm256_setzero_ps();
1513 __m256 vop96 = _mm256_setzero_ps();
1514 __m256 vop104 = _mm256_setzero_ps();
1515 __m256 vop112 = _mm256_setzero_ps();
1516 __m256 vop120 = _mm256_setzero_ps();
1517 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1518 return false;
1519 }
1520 int64_t end_offset = offsets[rangeIndex + 1];
1521 int64_t length = end_offset - offsets[rangeIndex];
1522 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1523 ++dataInd) {
1524 const int64_t idx = indices[dataInd];
1525 if (idx < 0 || idx >= data_size) {
1526 return false;
1527 }
1528 float wgt = 1.f;
1529 if (weights) {
1530 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1531 }
1532 __m256 vwgt = _mm256_set1_ps(wgt);
1533 const at::Half* ip = &input[idx * fused_block_size];
1534 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1535 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1536 ? (dataInd + prefdist_T0)
1537 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1538 : dataInd;
1539 const int64_t idx_pref_T0 = indices[next_T0];
1540 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1541 return false;
1542 }
1543 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1544 vop0 = _mm256_fmadd_ps(
1545 vwgt,
1546 _mm256_cvtph_ps(
1547 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1548 vop0);
1549 _mm_prefetch(
1550 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1551 vop8 = _mm256_fmadd_ps(
1552 vwgt,
1553 _mm256_cvtph_ps(
1554 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1555 vop8);
1556 // skip unnecessary prefetch of (&ip_next_T0[8])
1557 vop16 = _mm256_fmadd_ps(
1558 vwgt,
1559 _mm256_cvtph_ps(
1560 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1561 vop16);
1562 // skip unnecessary prefetch of (&ip_next_T0[16])
1563 vop24 = _mm256_fmadd_ps(
1564 vwgt,
1565 _mm256_cvtph_ps(
1566 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1567 vop24);
1568 // skip unnecessary prefetch of (&ip_next_T0[24])
1569 vop32 = _mm256_fmadd_ps(
1570 vwgt,
1571 _mm256_cvtph_ps(
1572 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1573 vop32);
1574 _mm_prefetch(
1575 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1576 vop40 = _mm256_fmadd_ps(
1577 vwgt,
1578 _mm256_cvtph_ps(
1579 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1580 vop40);
1581 // skip unnecessary prefetch of (&ip_next_T0[40])
1582 vop48 = _mm256_fmadd_ps(
1583 vwgt,
1584 _mm256_cvtph_ps(
1585 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1586 vop48);
1587 // skip unnecessary prefetch of (&ip_next_T0[48])
1588 vop56 = _mm256_fmadd_ps(
1589 vwgt,
1590 _mm256_cvtph_ps(
1591 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1592 vop56);
1593 // skip unnecessary prefetch of (&ip_next_T0[56])
1594 vop64 = _mm256_fmadd_ps(
1595 vwgt,
1596 _mm256_cvtph_ps(
1597 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1598 vop64);
1599 _mm_prefetch(
1600 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1601 vop72 = _mm256_fmadd_ps(
1602 vwgt,
1603 _mm256_cvtph_ps(
1604 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1605 vop72);
1606 // skip unnecessary prefetch of (&ip_next_T0[72])
1607 vop80 = _mm256_fmadd_ps(
1608 vwgt,
1609 _mm256_cvtph_ps(
1610 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1611 vop80);
1612 // skip unnecessary prefetch of (&ip_next_T0[80])
1613 vop88 = _mm256_fmadd_ps(
1614 vwgt,
1615 _mm256_cvtph_ps(
1616 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1617 vop88);
1618 // skip unnecessary prefetch of (&ip_next_T0[88])
1619 vop96 = _mm256_fmadd_ps(
1620 vwgt,
1621 _mm256_cvtph_ps(
1622 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1623 vop96);
1624 _mm_prefetch(
1625 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1626 vop104 = _mm256_fmadd_ps(
1627 vwgt,
1628 _mm256_cvtph_ps(
1629 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1630 vop104);
1631 // skip unnecessary prefetch of (&ip_next_T0[104])
1632 vop112 = _mm256_fmadd_ps(
1633 vwgt,
1634 _mm256_cvtph_ps(
1635 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1636 vop112);
1637 // skip unnecessary prefetch of (&ip_next_T0[112])
1638 vop120 = _mm256_fmadd_ps(
1639 vwgt,
1640 _mm256_cvtph_ps(
1641 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1642 vop120);
1643 // skip unnecessary prefetch of (&ip_next_T0[120])
1644 }
1645 if (!normalize_by_lengths || length == 0) {
1646 _mm256_storeu_ps(&op[0], vop0);
1647 _mm256_storeu_ps(&op[8], vop8);
1648 _mm256_storeu_ps(&op[16], vop16);
1649 _mm256_storeu_ps(&op[24], vop24);
1650 _mm256_storeu_ps(&op[32], vop32);
1651 _mm256_storeu_ps(&op[40], vop40);
1652 _mm256_storeu_ps(&op[48], vop48);
1653 _mm256_storeu_ps(&op[56], vop56);
1654 _mm256_storeu_ps(&op[64], vop64);
1655 _mm256_storeu_ps(&op[72], vop72);
1656 _mm256_storeu_ps(&op[80], vop80);
1657 _mm256_storeu_ps(&op[88], vop88);
1658 _mm256_storeu_ps(&op[96], vop96);
1659 _mm256_storeu_ps(&op[104], vop104);
1660 _mm256_storeu_ps(&op[112], vop112);
1661 _mm256_storeu_ps(&op[120], vop120);
1662 } else {
1663 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1664 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1665 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1666 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1667 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1668 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1669 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1670 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1671 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1672 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1673 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1674 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1675 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1676 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1677 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1678 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1679 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1680 }
1681 }
1682 } else if (block_size == 64) {
1683 // unrolling 8 times
1684 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1685 float* op = &out[rangeIndex * block_size];
1686 __m256 vop0 = _mm256_setzero_ps();
1687 __m256 vop8 = _mm256_setzero_ps();
1688 __m256 vop16 = _mm256_setzero_ps();
1689 __m256 vop24 = _mm256_setzero_ps();
1690 __m256 vop32 = _mm256_setzero_ps();
1691 __m256 vop40 = _mm256_setzero_ps();
1692 __m256 vop48 = _mm256_setzero_ps();
1693 __m256 vop56 = _mm256_setzero_ps();
1694 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1695 return false;
1696 }
1697 int64_t end_offset = offsets[rangeIndex + 1];
1698 int64_t length = end_offset - offsets[rangeIndex];
1699 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1700 ++dataInd) {
1701 const int64_t idx = indices[dataInd];
1702 if (idx < 0 || idx >= data_size) {
1703 return false;
1704 }
1705 float wgt = 1.f;
1706 if (weights) {
1707 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1708 }
1709 __m256 vwgt = _mm256_set1_ps(wgt);
1710 const at::Half* ip = &input[idx * fused_block_size];
1711 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1712 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1713 ? (dataInd + prefdist_T0)
1714 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1715 : dataInd;
1716 const int64_t idx_pref_T0 = indices[next_T0];
1717 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1718 return false;
1719 }
1720 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1721 vop0 = _mm256_fmadd_ps(
1722 vwgt,
1723 _mm256_cvtph_ps(
1724 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1725 vop0);
1726 _mm_prefetch(
1727 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1728 vop8 = _mm256_fmadd_ps(
1729 vwgt,
1730 _mm256_cvtph_ps(
1731 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1732 vop8);
1733 // skip unnecessary prefetch of (&ip_next_T0[8])
1734 vop16 = _mm256_fmadd_ps(
1735 vwgt,
1736 _mm256_cvtph_ps(
1737 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1738 vop16);
1739 // skip unnecessary prefetch of (&ip_next_T0[16])
1740 vop24 = _mm256_fmadd_ps(
1741 vwgt,
1742 _mm256_cvtph_ps(
1743 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1744 vop24);
1745 // skip unnecessary prefetch of (&ip_next_T0[24])
1746 vop32 = _mm256_fmadd_ps(
1747 vwgt,
1748 _mm256_cvtph_ps(
1749 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1750 vop32);
1751 _mm_prefetch(
1752 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1753 vop40 = _mm256_fmadd_ps(
1754 vwgt,
1755 _mm256_cvtph_ps(
1756 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1757 vop40);
1758 // skip unnecessary prefetch of (&ip_next_T0[40])
1759 vop48 = _mm256_fmadd_ps(
1760 vwgt,
1761 _mm256_cvtph_ps(
1762 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1763 vop48);
1764 // skip unnecessary prefetch of (&ip_next_T0[48])
1765 vop56 = _mm256_fmadd_ps(
1766 vwgt,
1767 _mm256_cvtph_ps(
1768 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1769 vop56);
1770 // skip unnecessary prefetch of (&ip_next_T0[56])
1771 }
1772 if (!normalize_by_lengths || length == 0) {
1773 _mm256_storeu_ps(&op[0], vop0);
1774 _mm256_storeu_ps(&op[8], vop8);
1775 _mm256_storeu_ps(&op[16], vop16);
1776 _mm256_storeu_ps(&op[24], vop24);
1777 _mm256_storeu_ps(&op[32], vop32);
1778 _mm256_storeu_ps(&op[40], vop40);
1779 _mm256_storeu_ps(&op[48], vop48);
1780 _mm256_storeu_ps(&op[56], vop56);
1781 } else {
1782 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1783 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1784 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1785 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1786 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1787 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1788 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1789 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1790 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1791 }
1792 }
1793 } else if (block_size == 32) {
1794 // unrolling 4 times
1795 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1796 float* op = &out[rangeIndex * block_size];
1797 __m256 vop0 = _mm256_setzero_ps();
1798 __m256 vop8 = _mm256_setzero_ps();
1799 __m256 vop16 = _mm256_setzero_ps();
1800 __m256 vop24 = _mm256_setzero_ps();
1801 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1802 return false;
1803 }
1804 int64_t end_offset = offsets[rangeIndex + 1];
1805 int64_t length = end_offset - offsets[rangeIndex];
1806 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1807 ++dataInd) {
1808 const int64_t idx = indices[dataInd];
1809 if (idx < 0 || idx >= data_size) {
1810 return false;
1811 }
1812 float wgt = 1.f;
1813 if (weights) {
1814 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1815 }
1816 __m256 vwgt = _mm256_set1_ps(wgt);
1817 const at::Half* ip = &input[idx * fused_block_size];
1818 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1819 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1820 ? (dataInd + prefdist_T0)
1821 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1822 : dataInd;
1823 const int64_t idx_pref_T0 = indices[next_T0];
1824 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1825 return false;
1826 }
1827 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1828 vop0 = _mm256_fmadd_ps(
1829 vwgt,
1830 _mm256_cvtph_ps(
1831 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1832 vop0);
1833 _mm_prefetch(
1834 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1835 vop8 = _mm256_fmadd_ps(
1836 vwgt,
1837 _mm256_cvtph_ps(
1838 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1839 vop8);
1840 // skip unnecessary prefetch of (&ip_next_T0[8])
1841 vop16 = _mm256_fmadd_ps(
1842 vwgt,
1843 _mm256_cvtph_ps(
1844 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1845 vop16);
1846 // skip unnecessary prefetch of (&ip_next_T0[16])
1847 vop24 = _mm256_fmadd_ps(
1848 vwgt,
1849 _mm256_cvtph_ps(
1850 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1851 vop24);
1852 // skip unnecessary prefetch of (&ip_next_T0[24])
1853 }
1854 if (!normalize_by_lengths || length == 0) {
1855 _mm256_storeu_ps(&op[0], vop0);
1856 _mm256_storeu_ps(&op[8], vop8);
1857 _mm256_storeu_ps(&op[16], vop16);
1858 _mm256_storeu_ps(&op[24], vop24);
1859 } else {
1860 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1861 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1862 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1863 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1864 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1865 }
1866 }
1867 } else if (block_size == 16) {
1868 // unrolling 2 times
1869 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1870 float* op = &out[rangeIndex * block_size];
1871 __m256 vop0 = _mm256_setzero_ps();
1872 __m256 vop8 = _mm256_setzero_ps();
1873 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1874 return false;
1875 }
1876 int64_t end_offset = offsets[rangeIndex + 1];
1877 int64_t length = end_offset - offsets[rangeIndex];
1878 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1879 ++dataInd) {
1880 const int64_t idx = indices[dataInd];
1881 if (idx < 0 || idx >= data_size) {
1882 return false;
1883 }
1884 float wgt = 1.f;
1885 if (weights) {
1886 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1887 }
1888 __m256 vwgt = _mm256_set1_ps(wgt);
1889 const at::Half* ip = &input[idx * fused_block_size];
1890 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1891 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1892 ? (dataInd + prefdist_T0)
1893 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1894 : dataInd;
1895 const int64_t idx_pref_T0 = indices[next_T0];
1896 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1897 return false;
1898 }
1899 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1900 vop0 = _mm256_fmadd_ps(
1901 vwgt,
1902 _mm256_cvtph_ps(
1903 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1904 vop0);
1905 _mm_prefetch(
1906 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1907 vop8 = _mm256_fmadd_ps(
1908 vwgt,
1909 _mm256_cvtph_ps(
1910 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1911 vop8);
1912 // skip unnecessary prefetch of (&ip_next_T0[8])
1913 }
1914 if (!normalize_by_lengths || length == 0) {
1915 _mm256_storeu_ps(&op[0], vop0);
1916 _mm256_storeu_ps(&op[8], vop8);
1917 } else {
1918 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1919 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1920 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1921 }
1922 }
1923 } else {
1924 // generic code
1925 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
1926 alignas(64) at::Half vtmp1[8] = {0};
1927 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1928 float* op = &out[rangeIndex * block_size];
1929 int64_t j = 0;
1930 for (; j + 8 <= block_size; j += 8) {
1931 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1932 }
1933 for (; j < block_size; j++) {
1934 op[j] = 0.0f;
1935 }
1936 if (dataInd != offsets[rangeIndex] - offsets[0]) {
1937 return false;
1938 }
1939 int64_t end_offset = offsets[rangeIndex + 1];
1940 int64_t length = end_offset - offsets[rangeIndex];
1941 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1942 ++dataInd) {
1943 const int64_t idx = indices[dataInd];
1944 if (idx < 0 || idx >= data_size) {
1945 return false;
1946 }
1947 float wgt = 1.f;
1948 if (weights) {
1949 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1950 }
1951 __m256 vwgt = _mm256_set1_ps(wgt);
1952 const at::Half* ip = &input[idx * fused_block_size];
1953 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1954 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1955 ? (dataInd + prefdist_T0)
1956 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1957 : dataInd;
1958 const int64_t idx_pref_T0 = indices[next_T0];
1959 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1960 return false;
1961 }
1962 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1963 j = 0;
1964 for (; j + 8 <= block_size; j += 8) {
1965 _mm256_storeu_ps(
1966 &op[j],
1967 _mm256_fmadd_ps(
1968 vwgt,
1969 _mm256_cvtph_ps(_mm_loadu_si128(
1970 reinterpret_cast<const __m128i*>(&ip[j]))),
1971 _mm256_loadu_ps(&op[j])));
1972 _mm_prefetch(
1973 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1974 }
1975 for (; j < block_size; j++) {
1976 vtmp1[0] = ip[j];
1977 __m256 vtmp2 =
1978 _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
1979 op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
1980 }
1981 }
1982 if (normalize_by_lengths && length) {
1983 float len_inv = 1.0f / length;
1984 __m256 vlen_inv = _mm256_set1_ps(len_inv);
1985 j = 0;
1986 for (; j + 8 <= block_size; j += 8) {
1987 _mm256_storeu_ps(
1988 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1989 }
1990 for (; j < block_size; j++) {
1991 op[j] = len_inv * op[j];
1992 }
1993 }
1994 }
1995 }
1996 return dataInd == index_size;
1997 }
1998 bool EmbeddingLookupIdx_int64_t_half_float_false__avx2_fma(
1999 const int64_t block_size,
2000 const int64_t output_size,
2001 const int64_t index_size,
2002 const int64_t data_size,
2003 const at::Half* input,
2004 const int64_t* indices,
2005 const int64_t* offsets,
2006 const float* weights,
2007 const float* scale_bias,
2008 bool normalize_by_lengths,
2009 float* out) {
2010 return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<false>(
2011 block_size,
2012 output_size,
2013 index_size,
2014 data_size,
2015 input,
2016 indices,
2017 offsets,
2018 weights,
2019 scale_bias,
2020 normalize_by_lengths,
2021 out);
2022 }
2023 bool EmbeddingLookupIdx_int64_t_half_float_true__avx2_fma(
2024 const int64_t block_size,
2025 const int64_t output_size,
2026 const int64_t index_size,
2027 const int64_t data_size,
2028 const at::Half* input,
2029 const int64_t* indices,
2030 const int64_t* offsets,
2031 const float* weights,
2032 const float* scale_bias,
2033 bool normalize_by_lengths,
2034 float* out) {
2035 return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<true>(
2036 block_size,
2037 output_size,
2038 index_size,
2039 data_size,
2040 input,
2041 indices,
2042 offsets,
2043 weights,
2044 scale_bias,
2045 normalize_by_lengths,
2046 out);
2047 }
2048
2049 template <bool IS_WEIGHT_POSITIONAL>
2050 static bool EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma(
2051 const int64_t block_size,
2052 const int64_t output_size,
2053 const int64_t index_size,
2054 const int64_t data_size,
2055 const at::BFloat16* input,
2056 const int* indices,
2057 const int* offsets,
2058 const float* weights,
2059 const float* scale_bias,
2060 bool normalize_by_lengths,
2061 float* out) {
2062 const int prefdist_T0 = 16;
2063 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2064 const int fused_block_size = block_size + 0;
2065 int64_t dataInd = 0;
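  // BFloat16 rows are widened to fp32 without F16C: each 16-bit value is
  // zero-extended to 32 bits (_mm256_cvtepu16_epi32), shifted left by 16 so
  // the bf16 bits occupy the upper half of an IEEE-754 binary32, and the
  // result is bit-cast with _mm256_castsi256_ps. Scalar equivalent:
  //   float f = bit_cast<float>(uint32_t(bf16_bits) << 16);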
2066 if (block_size == 128) {
2067 // unrolling 16 times
2068 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2069 float* op = &out[rangeIndex * block_size];
2070 __m256 vop0 = _mm256_setzero_ps();
2071 __m256 vop8 = _mm256_setzero_ps();
2072 __m256 vop16 = _mm256_setzero_ps();
2073 __m256 vop24 = _mm256_setzero_ps();
2074 __m256 vop32 = _mm256_setzero_ps();
2075 __m256 vop40 = _mm256_setzero_ps();
2076 __m256 vop48 = _mm256_setzero_ps();
2077 __m256 vop56 = _mm256_setzero_ps();
2078 __m256 vop64 = _mm256_setzero_ps();
2079 __m256 vop72 = _mm256_setzero_ps();
2080 __m256 vop80 = _mm256_setzero_ps();
2081 __m256 vop88 = _mm256_setzero_ps();
2082 __m256 vop96 = _mm256_setzero_ps();
2083 __m256 vop104 = _mm256_setzero_ps();
2084 __m256 vop112 = _mm256_setzero_ps();
2085 __m256 vop120 = _mm256_setzero_ps();
2086 if (dataInd != offsets[rangeIndex] - offsets[0]) {
2087 return false;
2088 }
2089 int64_t end_offset = offsets[rangeIndex + 1];
2090 int64_t length = end_offset - offsets[rangeIndex];
2091 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2092 ++dataInd) {
2093 const int idx = indices[dataInd];
2094 if (idx < 0 || idx >= data_size) {
2095 return false;
2096 }
2097 float wgt = 1.f;
2098 if (weights) {
2099 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2100 }
2101 __m256 vwgt = _mm256_set1_ps(wgt);
2102 const at::BFloat16* ip = &input[idx * fused_block_size];
2103 const int next_T0 = (dataInd < index_size - prefdist_T0)
2104 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2105 ? (dataInd + prefdist_T0)
2106 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2107 : dataInd;
2108 const int idx_pref_T0 = indices[next_T0];
2109 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2110 return false;
2111 }
2112 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2113 vop0 = _mm256_fmadd_ps(
2114 vwgt,
2115 _mm256_castsi256_ps(_mm256_slli_epi32(
2116 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2117 reinterpret_cast<const __m128i*>(ip + (0)))),
2118 16)),
2119 vop0);
2120 _mm_prefetch(
2121 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2122 vop8 = _mm256_fmadd_ps(
2123 vwgt,
2124 _mm256_castsi256_ps(_mm256_slli_epi32(
2125 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2126 reinterpret_cast<const __m128i*>(ip + (8)))),
2127 16)),
2128 vop8);
2129 // skip unnecessary prefetch of (&ip_next_T0[8])
2130 vop16 = _mm256_fmadd_ps(
2131 vwgt,
2132 _mm256_castsi256_ps(_mm256_slli_epi32(
2133 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2134 reinterpret_cast<const __m128i*>(ip + (16)))),
2135 16)),
2136 vop16);
2137 // skip unnecessary prefetch of (&ip_next_T0[16])
2138 vop24 = _mm256_fmadd_ps(
2139 vwgt,
2140 _mm256_castsi256_ps(_mm256_slli_epi32(
2141 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2142 reinterpret_cast<const __m128i*>(ip + (24)))),
2143 16)),
2144 vop24);
2145 // skip unnecessary prefetch of (&ip_next_T0[24])
2146 vop32 = _mm256_fmadd_ps(
2147 vwgt,
2148 _mm256_castsi256_ps(_mm256_slli_epi32(
2149 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2150 reinterpret_cast<const __m128i*>(ip + (32)))),
2151 16)),
2152 vop32);
2153 _mm_prefetch(
2154 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2155 vop40 = _mm256_fmadd_ps(
2156 vwgt,
2157 _mm256_castsi256_ps(_mm256_slli_epi32(
2158 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2159 reinterpret_cast<const __m128i*>(ip + (40)))),
2160 16)),
2161 vop40);
2162 // skip unnecessary prefetch of (&ip_next_T0[40])
2163 vop48 = _mm256_fmadd_ps(
2164 vwgt,
2165 _mm256_castsi256_ps(_mm256_slli_epi32(
2166 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2167 reinterpret_cast<const __m128i*>(ip + (48)))),
2168 16)),
2169 vop48);
2170 // skip unnecessary prefetch of (&ip_next_T0[48])
2171 vop56 = _mm256_fmadd_ps(
2172 vwgt,
2173 _mm256_castsi256_ps(_mm256_slli_epi32(
2174 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2175 reinterpret_cast<const __m128i*>(ip + (56)))),
2176 16)),
2177 vop56);
2178 // skip unnecessary prefetch of (&ip_next_T0[56])
2179 vop64 = _mm256_fmadd_ps(
2180 vwgt,
2181 _mm256_castsi256_ps(_mm256_slli_epi32(
2182 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2183 reinterpret_cast<const __m128i*>(ip + (64)))),
2184 16)),
2185 vop64);
2186 _mm_prefetch(
2187 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2188 vop72 = _mm256_fmadd_ps(
2189 vwgt,
2190 _mm256_castsi256_ps(_mm256_slli_epi32(
2191 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2192 reinterpret_cast<const __m128i*>(ip + (72)))),
2193 16)),
2194 vop72);
2195 // skip unnecessary prefetch of (&ip_next_T0[72])
2196 vop80 = _mm256_fmadd_ps(
2197 vwgt,
2198 _mm256_castsi256_ps(_mm256_slli_epi32(
2199 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2200 reinterpret_cast<const __m128i*>(ip + (80)))),
2201 16)),
2202 vop80);
2203 // skip unnecessary prefetch of (&ip_next_T0[80])
2204 vop88 = _mm256_fmadd_ps(
2205 vwgt,
2206 _mm256_castsi256_ps(_mm256_slli_epi32(
2207 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2208 reinterpret_cast<const __m128i*>(ip + (88)))),
2209 16)),
2210 vop88);
2211 // skip unnecessary prefetch of (&ip_next_T0[88])
2212 vop96 = _mm256_fmadd_ps(
2213 vwgt,
2214 _mm256_castsi256_ps(_mm256_slli_epi32(
2215 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2216 reinterpret_cast<const __m128i*>(ip + (96)))),
2217 16)),
2218 vop96);
2219 _mm_prefetch(
2220 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
2221 vop104 = _mm256_fmadd_ps(
2222 vwgt,
2223 _mm256_castsi256_ps(_mm256_slli_epi32(
2224 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2225 reinterpret_cast<const __m128i*>(ip + (104)))),
2226 16)),
2227 vop104);
2228 // skip unnecessary prefetch of (&ip_next_T0[104])
2229 vop112 = _mm256_fmadd_ps(
2230 vwgt,
2231 _mm256_castsi256_ps(_mm256_slli_epi32(
2232 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2233 reinterpret_cast<const __m128i*>(ip + (112)))),
2234 16)),
2235 vop112);
2236 // skip unnecessary prefetch of (&ip_next_T0[112])
2237 vop120 = _mm256_fmadd_ps(
2238 vwgt,
2239 _mm256_castsi256_ps(_mm256_slli_epi32(
2240 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2241 reinterpret_cast<const __m128i*>(ip + (120)))),
2242 16)),
2243 vop120);
2244 // skip unnecessary prefetch of (&ip_next_T0[120])
2245 }
2246 if (!normalize_by_lengths || length == 0) {
2247 _mm256_storeu_ps(&op[0], vop0);
2248 _mm256_storeu_ps(&op[8], vop8);
2249 _mm256_storeu_ps(&op[16], vop16);
2250 _mm256_storeu_ps(&op[24], vop24);
2251 _mm256_storeu_ps(&op[32], vop32);
2252 _mm256_storeu_ps(&op[40], vop40);
2253 _mm256_storeu_ps(&op[48], vop48);
2254 _mm256_storeu_ps(&op[56], vop56);
2255 _mm256_storeu_ps(&op[64], vop64);
2256 _mm256_storeu_ps(&op[72], vop72);
2257 _mm256_storeu_ps(&op[80], vop80);
2258 _mm256_storeu_ps(&op[88], vop88);
2259 _mm256_storeu_ps(&op[96], vop96);
2260 _mm256_storeu_ps(&op[104], vop104);
2261 _mm256_storeu_ps(&op[112], vop112);
2262 _mm256_storeu_ps(&op[120], vop120);
2263 } else {
2264 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2265 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2266 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2267 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2268 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2269 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2270 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2271 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2272 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2273 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2274 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2275 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2276 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2277 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2278 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2279 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2280 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2281 }
2282 }
2283 } else if (block_size == 64) {
2284 // unrolling 8 times
2285 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2286 float* op = &out[rangeIndex * block_size];
2287 __m256 vop0 = _mm256_setzero_ps();
2288 __m256 vop8 = _mm256_setzero_ps();
2289 __m256 vop16 = _mm256_setzero_ps();
2290 __m256 vop24 = _mm256_setzero_ps();
2291 __m256 vop32 = _mm256_setzero_ps();
2292 __m256 vop40 = _mm256_setzero_ps();
2293 __m256 vop48 = _mm256_setzero_ps();
2294 __m256 vop56 = _mm256_setzero_ps();
2295 if (dataInd != offsets[rangeIndex] - offsets[0]) {
2296 return false;
2297 }
2298 int64_t end_offset = offsets[rangeIndex + 1];
2299 int64_t length = end_offset - offsets[rangeIndex];
2300 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2301 ++dataInd) {
2302 const int idx = indices[dataInd];
2303 if (idx < 0 || idx >= data_size) {
2304 return false;
2305 }
2306 float wgt = 1.f;
2307 if (weights) {
2308 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2309 }
2310 __m256 vwgt = _mm256_set1_ps(wgt);
2311 const at::BFloat16* ip = &input[idx * fused_block_size];
2312 const int next_T0 = (dataInd < index_size - prefdist_T0)
2313 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2314 ? (dataInd + prefdist_T0)
2315 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2316 : dataInd;
2317 const int idx_pref_T0 = indices[next_T0];
2318 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2319 return false;
2320 }
2321 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2322 vop0 = _mm256_fmadd_ps(
2323 vwgt,
2324 _mm256_castsi256_ps(_mm256_slli_epi32(
2325 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2326 reinterpret_cast<const __m128i*>(ip + (0)))),
2327 16)),
2328 vop0);
2329 _mm_prefetch(
2330 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2331 vop8 = _mm256_fmadd_ps(
2332 vwgt,
2333 _mm256_castsi256_ps(_mm256_slli_epi32(
2334 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2335 reinterpret_cast<const __m128i*>(ip + (8)))),
2336 16)),
2337 vop8);
2338 // skip unnecessary prefetch of (&ip_next_T0[8])
2339 vop16 = _mm256_fmadd_ps(
2340 vwgt,
2341 _mm256_castsi256_ps(_mm256_slli_epi32(
2342 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2343 reinterpret_cast<const __m128i*>(ip + (16)))),
2344 16)),
2345 vop16);
2346 // skip unnecessary prefetch of (&ip_next_T0[16])
2347 vop24 = _mm256_fmadd_ps(
2348 vwgt,
2349 _mm256_castsi256_ps(_mm256_slli_epi32(
2350 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2351 reinterpret_cast<const __m128i*>(ip + (24)))),
2352 16)),
2353 vop24);
2354 // skip unnecessary prefetch of (&ip_next_T0[24])
2355 vop32 = _mm256_fmadd_ps(
2356 vwgt,
2357 _mm256_castsi256_ps(_mm256_slli_epi32(
2358 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2359 reinterpret_cast<const __m128i*>(ip + (32)))),
2360 16)),
2361 vop32);
2362 _mm_prefetch(
2363 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2364 vop40 = _mm256_fmadd_ps(
2365 vwgt,
2366 _mm256_castsi256_ps(_mm256_slli_epi32(
2367 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2368 reinterpret_cast<const __m128i*>(ip + (40)))),
2369 16)),
2370 vop40);
2371 // skip unnecessary prefetch of (&ip_next_T0[40])
2372 vop48 = _mm256_fmadd_ps(
2373 vwgt,
2374 _mm256_castsi256_ps(_mm256_slli_epi32(
2375 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2376 reinterpret_cast<const __m128i*>(ip + (48)))),
2377 16)),
2378 vop48);
2379 // skip unnecessary prefetch of (&ip_next_T0[48])
2380 vop56 = _mm256_fmadd_ps(
2381 vwgt,
2382 _mm256_castsi256_ps(_mm256_slli_epi32(
2383 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2384 reinterpret_cast<const __m128i*>(ip + (56)))),
2385 16)),
2386 vop56);
2387 // skip unnecessary prefetch of (&ip_next_T0[56])
2388 }
2389 if (!normalize_by_lengths || length == 0) {
2390 _mm256_storeu_ps(&op[0], vop0);
2391 _mm256_storeu_ps(&op[8], vop8);
2392 _mm256_storeu_ps(&op[16], vop16);
2393 _mm256_storeu_ps(&op[24], vop24);
2394 _mm256_storeu_ps(&op[32], vop32);
2395 _mm256_storeu_ps(&op[40], vop40);
2396 _mm256_storeu_ps(&op[48], vop48);
2397 _mm256_storeu_ps(&op[56], vop56);
2398 } else {
2399 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2400 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2401 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2402 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2403 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2404 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2405 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2406 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2407 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2408 }
2409 }
2410 } else if (block_size == 32) {
2411 // unrolling 4 times
2412 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2413 float* op = &out[rangeIndex * block_size];
2414 __m256 vop0 = _mm256_setzero_ps();
2415 __m256 vop8 = _mm256_setzero_ps();
2416 __m256 vop16 = _mm256_setzero_ps();
2417 __m256 vop24 = _mm256_setzero_ps();
2418 if (dataInd != offsets[rangeIndex] - offsets[0]) {
2419 return false;
2420 }
2421 int64_t end_offset = offsets[rangeIndex + 1];
2422 int64_t length = end_offset - offsets[rangeIndex];
2423 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2424 ++dataInd) {
2425 const int idx = indices[dataInd];
2426 if (idx < 0 || idx >= data_size) {
2427 return false;
2428 }
2429 float wgt = 1.f;
2430 if (weights) {
2431 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2432 }
2433 __m256 vwgt = _mm256_set1_ps(wgt);
2434 const at::BFloat16* ip = &input[idx * fused_block_size];
2435 const int next_T0 = (dataInd < index_size - prefdist_T0)
2436 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2437 ? (dataInd + prefdist_T0)
2438 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2439 : dataInd;
2440 const int idx_pref_T0 = indices[next_T0];
2441 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2442 return false;
2443 }
2444 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2445 vop0 = _mm256_fmadd_ps(
2446 vwgt,
2447 _mm256_castsi256_ps(_mm256_slli_epi32(
2448 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2449 reinterpret_cast<const __m128i*>(ip + (0)))),
2450 16)),
2451 vop0);
2452 _mm_prefetch(
2453 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2454 vop8 = _mm256_fmadd_ps(
2455 vwgt,
2456 _mm256_castsi256_ps(_mm256_slli_epi32(
2457 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2458 reinterpret_cast<const __m128i*>(ip + (8)))),
2459 16)),
2460 vop8);
2461 // skip unnecessary prefetch of (&ip_next_T0[8])
2462 vop16 = _mm256_fmadd_ps(
2463 vwgt,
2464 _mm256_castsi256_ps(_mm256_slli_epi32(
2465 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2466 reinterpret_cast<const __m128i*>(ip + (16)))),
2467 16)),
2468 vop16);
2469 // skip unnecessary prefetch of (&ip_next_T0[16])
2470 vop24 = _mm256_fmadd_ps(
2471 vwgt,
2472 _mm256_castsi256_ps(_mm256_slli_epi32(
2473 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2474 reinterpret_cast<const __m128i*>(ip + (24)))),
2475 16)),
2476 vop24);
2477 // skip unnecessary prefetch of (&ip_next_T0[24])
2478 }
2479 if (!normalize_by_lengths || length == 0) {
2480 _mm256_storeu_ps(&op[0], vop0);
2481 _mm256_storeu_ps(&op[8], vop8);
2482 _mm256_storeu_ps(&op[16], vop16);
2483 _mm256_storeu_ps(&op[24], vop24);
2484 } else {
2485 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2486 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2487 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2488 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2489 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2490 }
2491 }
2492 } else if (block_size == 16) {
2493 // unrolling 2 times
2494 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2495 float* op = &out[rangeIndex * block_size];
2496 __m256 vop0 = _mm256_setzero_ps();
2497 __m256 vop8 = _mm256_setzero_ps();
2498 if (dataInd != offsets[rangeIndex] - offsets[0]) {
2499 return false;
2500 }
2501 int64_t end_offset = offsets[rangeIndex + 1];
2502 int64_t length = end_offset - offsets[rangeIndex];
2503 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2504 ++dataInd) {
2505 const int idx = indices[dataInd];
2506 if (idx < 0 || idx >= data_size) {
2507 return false;
2508 }
2509 float wgt = 1.f;
2510 if (weights) {
2511 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2512 }
2513 __m256 vwgt = _mm256_set1_ps(wgt);
2514 const at::BFloat16* ip = &input[idx * fused_block_size];
2515 const int next_T0 = (dataInd < index_size - prefdist_T0)
2516 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2517 ? (dataInd + prefdist_T0)
2518 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2519 : dataInd;
2520 const int idx_pref_T0 = indices[next_T0];
2521 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2522 return false;
2523 }
2524 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2525 vop0 = _mm256_fmadd_ps(
2526 vwgt,
2527 _mm256_castsi256_ps(_mm256_slli_epi32(
2528 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2529 reinterpret_cast<const __m128i*>(ip + (0)))),
2530 16)),
2531 vop0);
2532 _mm_prefetch(
2533 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2534 vop8 = _mm256_fmadd_ps(
2535 vwgt,
2536 _mm256_castsi256_ps(_mm256_slli_epi32(
2537 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2538 reinterpret_cast<const __m128i*>(ip + (8)))),
2539 16)),
2540 vop8);
2541 // skip unnecessary prefetch of (&ip_next_T0[8])
2542 }
2543 if (!normalize_by_lengths || length == 0) {
2544 _mm256_storeu_ps(&op[0], vop0);
2545 _mm256_storeu_ps(&op[8], vop8);
2546 } else {
2547 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2548 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2549 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2550 }
2551 }
2552 } else {
2553 // generic code
2554 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
2555 alignas(64) at::BFloat16 vtmp1[8] = {0};
2556 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2557 float* op = &out[rangeIndex * block_size];
2558 int64_t j = 0;
2559 for (; j + 8 <= block_size; j += 8) {
2560 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2561 }
2562 for (; j < block_size; j++) {
2563 op[j] = 0.0f;
2564 }
2565 if (dataInd != offsets[rangeIndex] - offsets[0]) {
2566 return false;
2567 }
2568 int64_t end_offset = offsets[rangeIndex + 1];
2569 int64_t length = end_offset - offsets[rangeIndex];
2570 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2571 ++dataInd) {
2572 const int idx = indices[dataInd];
2573 if (idx < 0 || idx >= data_size) {
2574 return false;
2575 }
2576 float wgt = 1.f;
2577 if (weights) {
2578 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2579 }
2580 __m256 vwgt = _mm256_set1_ps(wgt);
2581 const at::BFloat16* ip = &input[idx * fused_block_size];
2582 const int next_T0 = (dataInd < index_size - prefdist_T0)
2583 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2584 ? (dataInd + prefdist_T0)
2585 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2586 : dataInd;
2587 const int idx_pref_T0 = indices[next_T0];
2588 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2589 return false;
2590 }
2591 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2592 j = 0;
2593 for (; j + 8 <= block_size; j += 8) {
2594 _mm256_storeu_ps(
2595 &op[j],
2596 _mm256_fmadd_ps(
2597 vwgt,
2598 _mm256_castsi256_ps(_mm256_slli_epi32(
2599 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2600 reinterpret_cast<const __m128i*>(&ip[j]))),
2601 16)),
2602 _mm256_loadu_ps(&op[j])));
2603 _mm_prefetch(
2604 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
2605 }
2606 for (; j < block_size; j++) {
2607 vtmp1[0] = ip[j];
2608 __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
2609 _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
2610 16));
2611 op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
2612 }
2613 }
2614 if (normalize_by_lengths && length) {
2615 float len_inv = 1.0f / length;
2616 __m256 vlen_inv = _mm256_set1_ps(len_inv);
2617 j = 0;
2618 for (; j + 8 <= block_size; j += 8) {
2619 _mm256_storeu_ps(
2620 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2621 }
2622 for (; j < block_size; j++) {
2623 op[j] = len_inv * op[j];
2624 }
2625 }
2626 }
2627 }
2628 return dataInd == index_size;
2629 }
2630 bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__avx2_fma(
2631 const int64_t block_size,
2632 const int64_t output_size,
2633 const int64_t index_size,
2634 const int64_t data_size,
2635 const at::BFloat16* input,
2636 const int* indices,
2637 const int* offsets,
2638 const float* weights,
2639 const float* scale_bias,
2640 bool normalize_by_lengths,
2641 float* out) {
2642 return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<false>(
2643 block_size,
2644 output_size,
2645 index_size,
2646 data_size,
2647 input,
2648 indices,
2649 offsets,
2650 weights,
2651 scale_bias,
2652 normalize_by_lengths,
2653 out);
2654 }
2655 bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__avx2_fma(
2656 const int64_t block_size,
2657 const int64_t output_size,
2658 const int64_t index_size,
2659 const int64_t data_size,
2660 const at::BFloat16* input,
2661 const int* indices,
2662 const int* offsets,
2663 const float* weights,
2664 const float* scale_bias,
2665 bool normalize_by_lengths,
2666 float* out) {
2667 return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<true>(
2668 block_size,
2669 output_size,
2670 index_size,
2671 data_size,
2672 input,
2673 indices,
2674 offsets,
2675 weights,
2676 scale_bias,
2677 normalize_by_lengths,
2678 out);
2679 }
2680
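// The int64_t-indexed BFloat16 kernel below is identical to the int32_t
// variant above except for the index and offset types.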
2681 template <bool IS_WEIGHT_POSITIONAL>
2682 static bool EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma(
2683 const int64_t block_size,
2684 const int64_t output_size,
2685 const int64_t index_size,
2686 const int64_t data_size,
2687 const at::BFloat16* input,
2688 const int64_t* indices,
2689 const int64_t* offsets,
2690 const float* weights,
2691 const float* scale_bias,
2692 bool normalize_by_lengths,
2693 float* out) {
2694 const int64_t prefdist_T0 = 16;
2695 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2696 const int64_t fused_block_size = block_size + 0;
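  // The row stride equals block_size: BFloat16 rows carry no fused per-row
  // scale/bias, so the scale_bias argument is unused in this specialization
  // (it presumably only matters for the quantized uint8 variants emitted by
  // the same generator).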
2697 int64_t dataInd = 0;
2698 if (block_size == 128) {
2699 // unrolling 16 times
2700 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2701 float* op = &out[rangeIndex * block_size];
2702 __m256 vop0 = _mm256_setzero_ps();
2703 __m256 vop8 = _mm256_setzero_ps();
2704 __m256 vop16 = _mm256_setzero_ps();
2705 __m256 vop24 = _mm256_setzero_ps();
2706 __m256 vop32 = _mm256_setzero_ps();
2707 __m256 vop40 = _mm256_setzero_ps();
2708 __m256 vop48 = _mm256_setzero_ps();
2709 __m256 vop56 = _mm256_setzero_ps();
2710 __m256 vop64 = _mm256_setzero_ps();
2711 __m256 vop72 = _mm256_setzero_ps();
2712 __m256 vop80 = _mm256_setzero_ps();
2713 __m256 vop88 = _mm256_setzero_ps();
2714 __m256 vop96 = _mm256_setzero_ps();
2715 __m256 vop104 = _mm256_setzero_ps();
2716 __m256 vop112 = _mm256_setzero_ps();
2717 __m256 vop120 = _mm256_setzero_ps();
2718 if (dataInd != offsets[rangeIndex] - offsets[0]) {
2719 return false;
2720 }
2721 int64_t end_offset = offsets[rangeIndex + 1];
2722 int64_t length = end_offset - offsets[rangeIndex];
2723 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2724 ++dataInd) {
2725 const int64_t idx = indices[dataInd];
2726 if (idx < 0 || idx >= data_size) {
2727 return false;
2728 }
2729 float wgt = 1.f;
2730 if (weights) {
2731 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2732 }
2733 __m256 vwgt = _mm256_set1_ps(wgt);
2734 const at::BFloat16* ip = &input[idx * fused_block_size];
2735 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2736 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2737 ? (dataInd + prefdist_T0)
2738 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2739 : dataInd;
2740 const int64_t idx_pref_T0 = indices[next_T0];
2741 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2742 return false;
2743 }
2744 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2745 vop0 = _mm256_fmadd_ps(
2746 vwgt,
2747 _mm256_castsi256_ps(_mm256_slli_epi32(
2748 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2749 reinterpret_cast<const __m128i*>(ip + (0)))),
2750 16)),
2751 vop0);
2752 _mm_prefetch(
2753 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2754 vop8 = _mm256_fmadd_ps(
2755 vwgt,
2756 _mm256_castsi256_ps(_mm256_slli_epi32(
2757 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2758 reinterpret_cast<const __m128i*>(ip + (8)))),
2759 16)),
2760 vop8);
2761 // skip unnecessary prefetch of (&ip_next_T0[8])
2762 vop16 = _mm256_fmadd_ps(
2763 vwgt,
2764 _mm256_castsi256_ps(_mm256_slli_epi32(
2765 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2766 reinterpret_cast<const __m128i*>(ip + (16)))),
2767 16)),
2768 vop16);
2769 // skip unnecessary prefetch of (&ip_next_T0[16])
2770 vop24 = _mm256_fmadd_ps(
2771 vwgt,
2772 _mm256_castsi256_ps(_mm256_slli_epi32(
2773 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2774 reinterpret_cast<const __m128i*>(ip + (24)))),
2775 16)),
2776 vop24);
2777 // skip unnecessary prefetch of (&ip_next_T0[24])
2778 vop32 = _mm256_fmadd_ps(
2779 vwgt,
2780 _mm256_castsi256_ps(_mm256_slli_epi32(
2781 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2782 reinterpret_cast<const __m128i*>(ip + (32)))),
2783 16)),
2784 vop32);
2785 _mm_prefetch(
2786 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2787 vop40 = _mm256_fmadd_ps(
2788 vwgt,
2789 _mm256_castsi256_ps(_mm256_slli_epi32(
2790 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2791 reinterpret_cast<const __m128i*>(ip + (40)))),
2792 16)),
2793 vop40);
2794 // skip unnecessary prefetch of (&ip_next_T0[40])
2795 vop48 = _mm256_fmadd_ps(
2796 vwgt,
2797 _mm256_castsi256_ps(_mm256_slli_epi32(
2798 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2799 reinterpret_cast<const __m128i*>(ip + (48)))),
2800 16)),
2801 vop48);
2802 // skip unnecessary prefetch of (&ip_next_T0[48])
2803 vop56 = _mm256_fmadd_ps(
2804 vwgt,
2805 _mm256_castsi256_ps(_mm256_slli_epi32(
2806 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2807 reinterpret_cast<const __m128i*>(ip + (56)))),
2808 16)),
2809 vop56);
2810 // skip unnecessary prefetch of (&ip_next_T0[56])
2811 vop64 = _mm256_fmadd_ps(
2812 vwgt,
2813 _mm256_castsi256_ps(_mm256_slli_epi32(
2814 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2815 reinterpret_cast<const __m128i*>(ip + (64)))),
2816 16)),
2817 vop64);
2818 _mm_prefetch(
2819 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2820 vop72 = _mm256_fmadd_ps(
2821 vwgt,
2822 _mm256_castsi256_ps(_mm256_slli_epi32(
2823 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2824 reinterpret_cast<const __m128i*>(ip + (72)))),
2825 16)),
2826 vop72);
2827 // skip unnecessary prefetch of (&ip_next_T0[72])
2828 vop80 = _mm256_fmadd_ps(
2829 vwgt,
2830 _mm256_castsi256_ps(_mm256_slli_epi32(
2831 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2832 reinterpret_cast<const __m128i*>(ip + (80)))),
2833 16)),
2834 vop80);
2835 // skip unnecessary prefetch of (&ip_next_T0[80])
2836 vop88 = _mm256_fmadd_ps(
2837 vwgt,
2838 _mm256_castsi256_ps(_mm256_slli_epi32(
2839 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2840 reinterpret_cast<const __m128i*>(ip + (88)))),
2841 16)),
2842 vop88);
2843 // skip unnecessary prefetch of (&ip_next_T0[88])
2844 vop96 = _mm256_fmadd_ps(
2845 vwgt,
2846 _mm256_castsi256_ps(_mm256_slli_epi32(
2847 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2848 reinterpret_cast<const __m128i*>(ip + (96)))),
2849 16)),
2850 vop96);
2851 _mm_prefetch(
2852 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
2853 vop104 = _mm256_fmadd_ps(
2854 vwgt,
2855 _mm256_castsi256_ps(_mm256_slli_epi32(
2856 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2857 reinterpret_cast<const __m128i*>(ip + (104)))),
2858 16)),
2859 vop104);
2860 // skip unnecessary prefetch of (&ip_next_T0[104])
2861 vop112 = _mm256_fmadd_ps(
2862 vwgt,
2863 _mm256_castsi256_ps(_mm256_slli_epi32(
2864 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2865 reinterpret_cast<const __m128i*>(ip + (112)))),
2866 16)),
2867 vop112);
2868 // skip unnecessary prefetch of (&ip_next_T0[112])
2869 vop120 = _mm256_fmadd_ps(
2870 vwgt,
2871 _mm256_castsi256_ps(_mm256_slli_epi32(
2872 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2873 reinterpret_cast<const __m128i*>(ip + (120)))),
2874 16)),
2875 vop120);
2876 // skip unnecessary prefetch of (&ip_next_T0[120])
2877 }
2878 if (!normalize_by_lengths || length == 0) {
2879 _mm256_storeu_ps(&op[0], vop0);
2880 _mm256_storeu_ps(&op[8], vop8);
2881 _mm256_storeu_ps(&op[16], vop16);
2882 _mm256_storeu_ps(&op[24], vop24);
2883 _mm256_storeu_ps(&op[32], vop32);
2884 _mm256_storeu_ps(&op[40], vop40);
2885 _mm256_storeu_ps(&op[48], vop48);
2886 _mm256_storeu_ps(&op[56], vop56);
2887 _mm256_storeu_ps(&op[64], vop64);
2888 _mm256_storeu_ps(&op[72], vop72);
2889 _mm256_storeu_ps(&op[80], vop80);
2890 _mm256_storeu_ps(&op[88], vop88);
2891 _mm256_storeu_ps(&op[96], vop96);
2892 _mm256_storeu_ps(&op[104], vop104);
2893 _mm256_storeu_ps(&op[112], vop112);
2894 _mm256_storeu_ps(&op[120], vop120);
2895 } else {
2896 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2897 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2898 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2899 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2900 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2901 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2902 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2903 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2904 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2905 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2906 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2907 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2908 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2909 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2910 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2911 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2912 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2913 }
2914 }
2915 } else if (block_size == 64) {
2916 // unrolling 8 times
2917 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2918 float* op = &out[rangeIndex * block_size];
2919 __m256 vop0 = _mm256_setzero_ps();
2920 __m256 vop8 = _mm256_setzero_ps();
2921 __m256 vop16 = _mm256_setzero_ps();
2922 __m256 vop24 = _mm256_setzero_ps();
2923 __m256 vop32 = _mm256_setzero_ps();
2924 __m256 vop40 = _mm256_setzero_ps();
2925 __m256 vop48 = _mm256_setzero_ps();
2926 __m256 vop56 = _mm256_setzero_ps();
2927 if (dataInd != offsets[rangeIndex] - offsets[0]) {
2928 return false;
2929 }
2930 int64_t end_offset = offsets[rangeIndex + 1];
2931 int64_t length = end_offset - offsets[rangeIndex];
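      // offsets are rebased against offsets[0]; `length` is the number of
      // rows pooled into this output row (the bag size)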
2932 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2933 ++dataInd) {
2934 const int64_t idx = indices[dataInd];
2935 if (idx < 0 || idx >= data_size) {
2936 return false;
2937 }
2938 float wgt = 1.f;
2939 if (weights) {
2940 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2941 }
2942 __m256 vwgt = _mm256_set1_ps(wgt);
2943 const at::BFloat16* ip = &input[idx * fused_block_size];
2944 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2945 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2946 ? (dataInd + prefdist_T0)
2947 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2948 : dataInd;
2949 const int64_t idx_pref_T0 = indices[next_T0];
2950 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2951 return false;
2952 }
2953 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
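        // bfloat16 -> float widening used below: zero-extend eight 16-bit
        // values to 32-bit lanes, shift left by 16 so each bf16 pattern fills
        // the upper half of an IEEE-754 binary32, then bit-cast the lanes to
        // float (an exact conversion, no rounding)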
2954 vop0 = _mm256_fmadd_ps(
2955 vwgt,
2956 _mm256_castsi256_ps(_mm256_slli_epi32(
2957 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2958 reinterpret_cast<const __m128i*>(ip + (0)))),
2959 16)),
2960 vop0);
2961 _mm_prefetch(
2962 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2963 vop8 = _mm256_fmadd_ps(
2964 vwgt,
2965 _mm256_castsi256_ps(_mm256_slli_epi32(
2966 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2967 reinterpret_cast<const __m128i*>(ip + (8)))),
2968 16)),
2969 vop8);
2970 // skip unnecessary prefetch of (&ip_next_T0[8])
2971 vop16 = _mm256_fmadd_ps(
2972 vwgt,
2973 _mm256_castsi256_ps(_mm256_slli_epi32(
2974 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2975 reinterpret_cast<const __m128i*>(ip + (16)))),
2976 16)),
2977 vop16);
2978 // skip unnecessary prefetch of (&ip_next_T0[16])
2979 vop24 = _mm256_fmadd_ps(
2980 vwgt,
2981 _mm256_castsi256_ps(_mm256_slli_epi32(
2982 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2983 reinterpret_cast<const __m128i*>(ip + (24)))),
2984 16)),
2985 vop24);
2986 // skip unnecessary prefetch of (&ip_next_T0[24])
2987 vop32 = _mm256_fmadd_ps(
2988 vwgt,
2989 _mm256_castsi256_ps(_mm256_slli_epi32(
2990 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2991 reinterpret_cast<const __m128i*>(ip + (32)))),
2992 16)),
2993 vop32);
2994 _mm_prefetch(
2995 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2996 vop40 = _mm256_fmadd_ps(
2997 vwgt,
2998 _mm256_castsi256_ps(_mm256_slli_epi32(
2999 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3000 reinterpret_cast<const __m128i*>(ip + (40)))),
3001 16)),
3002 vop40);
3003 // skip unnecessary prefetch of (&ip_next_T0[40])
3004 vop48 = _mm256_fmadd_ps(
3005 vwgt,
3006 _mm256_castsi256_ps(_mm256_slli_epi32(
3007 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3008 reinterpret_cast<const __m128i*>(ip + (48)))),
3009 16)),
3010 vop48);
3011 // skip unnecessary prefetch of (&ip_next_T0[48])
3012 vop56 = _mm256_fmadd_ps(
3013 vwgt,
3014 _mm256_castsi256_ps(_mm256_slli_epi32(
3015 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3016 reinterpret_cast<const __m128i*>(ip + (56)))),
3017 16)),
3018 vop56);
3019 // skip unnecessary prefetch of (&ip_next_T0[56])
3020 }
3021 if (!normalize_by_lengths || length == 0) {
3022 _mm256_storeu_ps(&op[0], vop0);
3023 _mm256_storeu_ps(&op[8], vop8);
3024 _mm256_storeu_ps(&op[16], vop16);
3025 _mm256_storeu_ps(&op[24], vop24);
3026 _mm256_storeu_ps(&op[32], vop32);
3027 _mm256_storeu_ps(&op[40], vop40);
3028 _mm256_storeu_ps(&op[48], vop48);
3029 _mm256_storeu_ps(&op[56], vop56);
3030 } else {
3031 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3032 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3033 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3034 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3035 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3036 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
3037 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
3038 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
3039 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
3040 }
3041 }
3042 } else if (block_size == 32) {
3043 // unrolling 4 times
3044 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3045 float* op = &out[rangeIndex * block_size];
3046 __m256 vop0 = _mm256_setzero_ps();
3047 __m256 vop8 = _mm256_setzero_ps();
3048 __m256 vop16 = _mm256_setzero_ps();
3049 __m256 vop24 = _mm256_setzero_ps();
3050 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3051 return false;
3052 }
3053 int64_t end_offset = offsets[rangeIndex + 1];
3054 int64_t length = end_offset - offsets[rangeIndex];
3055 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3056 ++dataInd) {
3057 const int64_t idx = indices[dataInd];
3058 if (idx < 0 || idx >= data_size) {
3059 return false;
3060 }
3061 float wgt = 1.f;
3062 if (weights) {
3063 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3064 }
3065 __m256 vwgt = _mm256_set1_ps(wgt);
3066 const at::BFloat16* ip = &input[idx * fused_block_size];
3067 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3068 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3069 ? (dataInd + prefdist_T0)
3070 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3071 : dataInd;
3072 const int64_t idx_pref_T0 = indices[next_T0];
3073 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3074 return false;
3075 }
3076 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3077 vop0 = _mm256_fmadd_ps(
3078 vwgt,
3079 _mm256_castsi256_ps(_mm256_slli_epi32(
3080 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3081 reinterpret_cast<const __m128i*>(ip + (0)))),
3082 16)),
3083 vop0);
3084 _mm_prefetch(
3085 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3086 vop8 = _mm256_fmadd_ps(
3087 vwgt,
3088 _mm256_castsi256_ps(_mm256_slli_epi32(
3089 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3090 reinterpret_cast<const __m128i*>(ip + (8)))),
3091 16)),
3092 vop8);
3093 // skip unnecessary prefetch of (&ip_next_T0[8])
3094 vop16 = _mm256_fmadd_ps(
3095 vwgt,
3096 _mm256_castsi256_ps(_mm256_slli_epi32(
3097 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3098 reinterpret_cast<const __m128i*>(ip + (16)))),
3099 16)),
3100 vop16);
3101 // skip unnecessary prefetch of (&ip_next_T0[16])
3102 vop24 = _mm256_fmadd_ps(
3103 vwgt,
3104 _mm256_castsi256_ps(_mm256_slli_epi32(
3105 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3106 reinterpret_cast<const __m128i*>(ip + (24)))),
3107 16)),
3108 vop24);
3109 // skip unnecessary prefetch of (&ip_next_T0[24])
3110 }
3111 if (!normalize_by_lengths || length == 0) {
3112 _mm256_storeu_ps(&op[0], vop0);
3113 _mm256_storeu_ps(&op[8], vop8);
3114 _mm256_storeu_ps(&op[16], vop16);
3115 _mm256_storeu_ps(&op[24], vop24);
3116 } else {
3117 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3118 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3119 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3120 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3121 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3122 }
3123 }
3124 } else if (block_size == 16) {
3125 // unrolling 2 times
3126 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3127 float* op = &out[rangeIndex * block_size];
3128 __m256 vop0 = _mm256_setzero_ps();
3129 __m256 vop8 = _mm256_setzero_ps();
3130 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3131 return false;
3132 }
3133 int64_t end_offset = offsets[rangeIndex + 1];
3134 int64_t length = end_offset - offsets[rangeIndex];
3135 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3136 ++dataInd) {
3137 const int64_t idx = indices[dataInd];
3138 if (idx < 0 || idx >= data_size) {
3139 return false;
3140 }
3141 float wgt = 1.f;
3142 if (weights) {
3143 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3144 }
3145 __m256 vwgt = _mm256_set1_ps(wgt);
3146 const at::BFloat16* ip = &input[idx * fused_block_size];
3147 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3148 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3149 ? (dataInd + prefdist_T0)
3150 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3151 : dataInd;
3152 const int64_t idx_pref_T0 = indices[next_T0];
3153 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3154 return false;
3155 }
3156 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3157 vop0 = _mm256_fmadd_ps(
3158 vwgt,
3159 _mm256_castsi256_ps(_mm256_slli_epi32(
3160 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3161 reinterpret_cast<const __m128i*>(ip + (0)))),
3162 16)),
3163 vop0);
3164 _mm_prefetch(
3165 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3166 vop8 = _mm256_fmadd_ps(
3167 vwgt,
3168 _mm256_castsi256_ps(_mm256_slli_epi32(
3169 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3170 reinterpret_cast<const __m128i*>(ip + (8)))),
3171 16)),
3172 vop8);
3173 // skip unnecessary prefetch of (&ip_next_T0[8])
3174 }
3175 if (!normalize_by_lengths || length == 0) {
3176 _mm256_storeu_ps(&op[0], vop0);
3177 _mm256_storeu_ps(&op[8], vop8);
3178 } else {
3179 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3180 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3181 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3182 }
3183 }
3184 } else {
3185 // generic code
3186 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
3187 alignas(64) at::BFloat16 vtmp1[8] = {0};
3188 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3189 float* op = &out[rangeIndex * block_size];
3190 int64_t j = 0;
3191 for (; j + 8 <= block_size; j += 8) {
3192 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
3193 }
3194 for (; j < block_size; j++) {
3195 op[j] = 0.0f;
3196 }
3197 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3198 return false;
3199 }
3200 int64_t end_offset = offsets[rangeIndex + 1];
3201 int64_t length = end_offset - offsets[rangeIndex];
3202 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3203 ++dataInd) {
3204 const int64_t idx = indices[dataInd];
3205 if (idx < 0 || idx >= data_size) {
3206 return false;
3207 }
3208 float wgt = 1.f;
3209 if (weights) {
3210 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3211 }
3212 __m256 vwgt = _mm256_set1_ps(wgt);
3213 const at::BFloat16* ip = &input[idx * fused_block_size];
3214 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3215 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3216 ? (dataInd + prefdist_T0)
3217 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3218 : dataInd;
3219 const int64_t idx_pref_T0 = indices[next_T0];
3220 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3221 return false;
3222 }
3223 const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3224 j = 0;
3225 for (; j + 8 <= block_size; j += 8) {
3226 _mm256_storeu_ps(
3227 &op[j],
3228 _mm256_fmadd_ps(
3229 vwgt,
3230 _mm256_castsi256_ps(_mm256_slli_epi32(
3231 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3232 reinterpret_cast<const __m128i*>(&ip[j]))),
3233 16)),
3234 _mm256_loadu_ps(&op[j])));
3235 _mm_prefetch(
3236 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
3237 }
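        // remainder (block_size % 8): widen one BFloat16 at a time through
        // the vtmp1 scratch buffer and accumulate with a scalar fma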
3238 for (; j < block_size; j++) {
3239 vtmp1[0] = ip[j];
3240 __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
3241 _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
3242 16));
3243 op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
3244 }
3245 }
3246 if (normalize_by_lengths && length) {
3247 float len_inv = 1.0f / length;
3248 __m256 vlen_inv = _mm256_set1_ps(len_inv);
3249 j = 0;
3250 for (; j + 8 <= block_size; j += 8) {
3251 _mm256_storeu_ps(
3252 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
3253 }
3254 for (; j < block_size; j++) {
3255 op[j] = len_inv * op[j];
3256 }
3257 }
3258 }
3259 }
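  // success only if every index was consumed exactly once; a mismatch
  // indicates inconsistent offsets/index_size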
3260 return dataInd == index_size;
3261 }
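// The exported (non-template) entry points below instantiate the kernel for
// the two IS_WEIGHT_POSITIONAL cases.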
3262 bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__avx2_fma(
3263 const int64_t block_size,
3264 const int64_t output_size,
3265 const int64_t index_size,
3266 const int64_t data_size,
3267 const at::BFloat16* input,
3268 const int64_t* indices,
3269 const int64_t* offsets,
3270 const float* weights,
3271 const float* scale_bias,
3272 bool normalize_by_lengths,
3273 float* out) {
3274 return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<false>(
3275 block_size,
3276 output_size,
3277 index_size,
3278 data_size,
3279 input,
3280 indices,
3281 offsets,
3282 weights,
3283 scale_bias,
3284 normalize_by_lengths,
3285 out);
3286 }
3287 bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__avx2_fma(
3288 const int64_t block_size,
3289 const int64_t output_size,
3290 const int64_t index_size,
3291 const int64_t data_size,
3292 const at::BFloat16* input,
3293 const int64_t* indices,
3294 const int64_t* offsets,
3295 const float* weights,
3296 const float* scale_bias,
3297 bool normalize_by_lengths,
3298 float* out) {
3299 return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<true>(
3300 block_size,
3301 output_size,
3302 index_size,
3303 data_size,
3304 input,
3305 indices,
3306 offsets,
3307 weights,
3308 scale_bias,
3309 normalize_by_lengths,
3310 out);
3311 }
3312
3313 template <bool IS_WEIGHT_POSITIONAL>
3314 static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
3315 const int64_t block_size,
3316 const int64_t output_size,
3317 const int64_t index_size,
3318 const int64_t data_size,
3319 const uint8_t* input,
3320 const int* indices,
3321 const int* offsets,
3322 const float* weights,
3323 const float* scale_bias,
3324 bool normalize_by_lengths,
3325 float* out) {
3326 const int prefdist_T0 = 16;
3327 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3328 const int fused_block_size = block_size + 0;
3329 int64_t dataInd = 0;
3330 if (block_size == 128) {
3331 // unrolling 16 times
3332 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3333 float* op = &out[rangeIndex * block_size];
3334 __m256 vop0 = _mm256_setzero_ps();
3335 __m256 vop8 = _mm256_setzero_ps();
3336 __m256 vop16 = _mm256_setzero_ps();
3337 __m256 vop24 = _mm256_setzero_ps();
3338 __m256 vop32 = _mm256_setzero_ps();
3339 __m256 vop40 = _mm256_setzero_ps();
3340 __m256 vop48 = _mm256_setzero_ps();
3341 __m256 vop56 = _mm256_setzero_ps();
3342 __m256 vop64 = _mm256_setzero_ps();
3343 __m256 vop72 = _mm256_setzero_ps();
3344 __m256 vop80 = _mm256_setzero_ps();
3345 __m256 vop88 = _mm256_setzero_ps();
3346 __m256 vop96 = _mm256_setzero_ps();
3347 __m256 vop104 = _mm256_setzero_ps();
3348 __m256 vop112 = _mm256_setzero_ps();
3349 __m256 vop120 = _mm256_setzero_ps();
3350 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3351 return false;
3352 }
3353 int64_t end_offset = offsets[rangeIndex + 1];
3354 int64_t length = end_offset - offsets[rangeIndex];
3355 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3356 ++dataInd) {
3357 const int idx = indices[dataInd];
3358 if (idx < 0 || idx >= data_size) {
3359 return false;
3360 }
3361 float wgt = 1.f;
3362 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3363 float bio;
3364 if (weights) {
3365 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3366 }
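        // rows are stored with per-row affine quantization:
        //   value = uint8 * scale_bias[2 * idx] + scale_bias[2 * idx + 1]
        // folding in the pooling weight gives wgt * scale as the multiplier
        // and wgt * bias as the per-element additive term (vbio below)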
3367 bio = wgt * scale_bias[2 * idx + 1];
3368 wgt = wgt * scale_bias[2 * idx];
3369 __m256 vbio = _mm256_set1_ps(bio);
3370 __m256 vwgt = _mm256_set1_ps(wgt);
3371 const uint8_t* ip = &input[idx * fused_block_size];
3372 const int next_T0 = (dataInd < index_size - prefdist_T0)
3373 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3374 ? (dataInd + prefdist_T0)
3375 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3376 : dataInd;
3377 const int idx_pref_T0 = indices[next_T0];
3378 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3379 return false;
3380 }
3381 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
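        // uint8 -> float widening: _mm_loadl_epi64 loads 8 bytes, they are
        // zero-extended to eight 32-bit integers and converted to float;
        // vbio is added to the accumulator as part of the fused multiply-add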
3382 vop0 = _mm256_fmadd_ps(
3383 vwgt,
3384 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3385 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3386 _mm256_add_ps(vop0, vbio));
3387 _mm_prefetch(
3388 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3389 vop8 = _mm256_fmadd_ps(
3390 vwgt,
3391 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3392 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3393 _mm256_add_ps(vop8, vbio));
3394 // skip unnecessary prefetch of (&ip_next_T0[8])
3395 vop16 = _mm256_fmadd_ps(
3396 vwgt,
3397 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3398 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3399 _mm256_add_ps(vop16, vbio));
3400 // skip unnecessary prefetch of (&ip_next_T0[16])
3401 vop24 = _mm256_fmadd_ps(
3402 vwgt,
3403 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3404 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3405 _mm256_add_ps(vop24, vbio));
3406 // skip unnecessary prefetch of (&ip_next_T0[24])
3407 vop32 = _mm256_fmadd_ps(
3408 vwgt,
3409 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3410 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
3411 _mm256_add_ps(vop32, vbio));
3412 // skip unnecessary prefetch of (&ip_next_T0[32])
3413 vop40 = _mm256_fmadd_ps(
3414 vwgt,
3415 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3416 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
3417 _mm256_add_ps(vop40, vbio));
3418 // skip unnecessary prefetch of (&ip_next_T0[40])
3419 vop48 = _mm256_fmadd_ps(
3420 vwgt,
3421 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3422 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
3423 _mm256_add_ps(vop48, vbio));
3424 // skip unnecessary prefetch of (&ip_next_T0[48])
3425 vop56 = _mm256_fmadd_ps(
3426 vwgt,
3427 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3428 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
3429 _mm256_add_ps(vop56, vbio));
3430 // skip unnecessary prefetch of (&ip_next_T0[56])
3431 vop64 = _mm256_fmadd_ps(
3432 vwgt,
3433 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3434 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
3435 _mm256_add_ps(vop64, vbio));
3436 _mm_prefetch(
3437 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
3438 vop72 = _mm256_fmadd_ps(
3439 vwgt,
3440 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3441 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
3442 _mm256_add_ps(vop72, vbio));
3443 // skip unnecessary prefetch of (&ip_next_T0[72])
3444 vop80 = _mm256_fmadd_ps(
3445 vwgt,
3446 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3447 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
3448 _mm256_add_ps(vop80, vbio));
3449 // skip unnecessary prefetch of (&ip_next_T0[80])
3450 vop88 = _mm256_fmadd_ps(
3451 vwgt,
3452 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3453 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
3454 _mm256_add_ps(vop88, vbio));
3455 // skip unnecessary prefetch of (&ip_next_T0[88])
3456 vop96 = _mm256_fmadd_ps(
3457 vwgt,
3458 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3459 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
3460 _mm256_add_ps(vop96, vbio));
3461 // skip unnecessary prefetch of (&ip_next_T0[96])
3462 vop104 = _mm256_fmadd_ps(
3463 vwgt,
3464 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3465 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
3466 _mm256_add_ps(vop104, vbio));
3467 // skip unnecessary prefetch of (&ip_next_T0[104])
3468 vop112 = _mm256_fmadd_ps(
3469 vwgt,
3470 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3471 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
3472 _mm256_add_ps(vop112, vbio));
3473 // skip unnecessary prefetch of (&ip_next_T0[112])
3474 vop120 = _mm256_fmadd_ps(
3475 vwgt,
3476 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3477 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
3478 _mm256_add_ps(vop120, vbio));
3479 // skip unnecessary prefetch of (&ip_next_T0[120])
3480 }
3481 if (!normalize_by_lengths || length == 0) {
3482 _mm256_storeu_ps(&op[0], vop0);
3483 _mm256_storeu_ps(&op[8], vop8);
3484 _mm256_storeu_ps(&op[16], vop16);
3485 _mm256_storeu_ps(&op[24], vop24);
3486 _mm256_storeu_ps(&op[32], vop32);
3487 _mm256_storeu_ps(&op[40], vop40);
3488 _mm256_storeu_ps(&op[48], vop48);
3489 _mm256_storeu_ps(&op[56], vop56);
3490 _mm256_storeu_ps(&op[64], vop64);
3491 _mm256_storeu_ps(&op[72], vop72);
3492 _mm256_storeu_ps(&op[80], vop80);
3493 _mm256_storeu_ps(&op[88], vop88);
3494 _mm256_storeu_ps(&op[96], vop96);
3495 _mm256_storeu_ps(&op[104], vop104);
3496 _mm256_storeu_ps(&op[112], vop112);
3497 _mm256_storeu_ps(&op[120], vop120);
3498 } else {
3499 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3500 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3501 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3502 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3503 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3504 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
3505 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
3506 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
3507 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
3508 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
3509 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
3510 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
3511 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
3512 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
3513 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
3514 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
3515 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
3516 }
3517 }
3518 } else if (block_size == 64) {
3519 // unrolling 8 times
3520 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3521 float* op = &out[rangeIndex * block_size];
3522 __m256 vop0 = _mm256_setzero_ps();
3523 __m256 vop8 = _mm256_setzero_ps();
3524 __m256 vop16 = _mm256_setzero_ps();
3525 __m256 vop24 = _mm256_setzero_ps();
3526 __m256 vop32 = _mm256_setzero_ps();
3527 __m256 vop40 = _mm256_setzero_ps();
3528 __m256 vop48 = _mm256_setzero_ps();
3529 __m256 vop56 = _mm256_setzero_ps();
3530 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3531 return false;
3532 }
3533 int64_t end_offset = offsets[rangeIndex + 1];
3534 int64_t length = end_offset - offsets[rangeIndex];
3535 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3536 ++dataInd) {
3537 const int idx = indices[dataInd];
3538 if (idx < 0 || idx >= data_size) {
3539 return false;
3540 }
3541 float wgt = 1.f;
3542 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3543 float bio;
3544 if (weights) {
3545 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3546 }
3547 bio = wgt * scale_bias[2 * idx + 1];
3548 wgt = wgt * scale_bias[2 * idx];
3549 __m256 vbio = _mm256_set1_ps(bio);
3550 __m256 vwgt = _mm256_set1_ps(wgt);
3551 const uint8_t* ip = &input[idx * fused_block_size];
3552 const int next_T0 = (dataInd < index_size - prefdist_T0)
3553 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3554 ? (dataInd + prefdist_T0)
3555 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3556 : dataInd;
3557 const int idx_pref_T0 = indices[next_T0];
3558 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3559 return false;
3560 }
3561 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3562 vop0 = _mm256_fmadd_ps(
3563 vwgt,
3564 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3565 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3566 _mm256_add_ps(vop0, vbio));
3567 _mm_prefetch(
3568 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3569 vop8 = _mm256_fmadd_ps(
3570 vwgt,
3571 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3572 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3573 _mm256_add_ps(vop8, vbio));
3574 // skip unnecessary prefetch of (&ip_next_T0[8])
3575 vop16 = _mm256_fmadd_ps(
3576 vwgt,
3577 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3578 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3579 _mm256_add_ps(vop16, vbio));
3580 // skip unnecessary prefetch of (&ip_next_T0[16])
3581 vop24 = _mm256_fmadd_ps(
3582 vwgt,
3583 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3584 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3585 _mm256_add_ps(vop24, vbio));
3586 // skip unnecessary prefetch of (&ip_next_T0[24])
3587 vop32 = _mm256_fmadd_ps(
3588 vwgt,
3589 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3590 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
3591 _mm256_add_ps(vop32, vbio));
3592 // skip unnecessary prefetch of (&ip_next_T0[32])
3593 vop40 = _mm256_fmadd_ps(
3594 vwgt,
3595 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3596 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
3597 _mm256_add_ps(vop40, vbio));
3598 // skip unnecessary prefetch of (&ip_next_T0[40])
3599 vop48 = _mm256_fmadd_ps(
3600 vwgt,
3601 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3602 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
3603 _mm256_add_ps(vop48, vbio));
3604 // skip unnecessary prefetch of (&ip_next_T0[48])
3605 vop56 = _mm256_fmadd_ps(
3606 vwgt,
3607 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3608 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
3609 _mm256_add_ps(vop56, vbio));
3610 // skip unnecessary prefetch of (&ip_next_T0[56])
3611 }
3612 if (!normalize_by_lengths || length == 0) {
3613 _mm256_storeu_ps(&op[0], vop0);
3614 _mm256_storeu_ps(&op[8], vop8);
3615 _mm256_storeu_ps(&op[16], vop16);
3616 _mm256_storeu_ps(&op[24], vop24);
3617 _mm256_storeu_ps(&op[32], vop32);
3618 _mm256_storeu_ps(&op[40], vop40);
3619 _mm256_storeu_ps(&op[48], vop48);
3620 _mm256_storeu_ps(&op[56], vop56);
3621 } else {
3622 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3623 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3624 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3625 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3626 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3627 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
3628 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
3629 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
3630 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
3631 }
3632 }
3633 } else if (block_size == 32) {
3634 // unrolling 4 times
3635 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3636 float* op = &out[rangeIndex * block_size];
3637 __m256 vop0 = _mm256_setzero_ps();
3638 __m256 vop8 = _mm256_setzero_ps();
3639 __m256 vop16 = _mm256_setzero_ps();
3640 __m256 vop24 = _mm256_setzero_ps();
3641 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3642 return false;
3643 }
3644 int64_t end_offset = offsets[rangeIndex + 1];
3645 int64_t length = end_offset - offsets[rangeIndex];
3646 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3647 ++dataInd) {
3648 const int idx = indices[dataInd];
3649 if (idx < 0 || idx >= data_size) {
3650 return false;
3651 }
3652 float wgt = 1.f;
3653 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3654 float bio;
3655 if (weights) {
3656 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3657 }
3658 bio = wgt * scale_bias[2 * idx + 1];
3659 wgt = wgt * scale_bias[2 * idx];
3660 __m256 vbio = _mm256_set1_ps(bio);
3661 __m256 vwgt = _mm256_set1_ps(wgt);
3662 const uint8_t* ip = &input[idx * fused_block_size];
3663 const int next_T0 = (dataInd < index_size - prefdist_T0)
3664 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3665 ? (dataInd + prefdist_T0)
3666 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3667 : dataInd;
3668 const int idx_pref_T0 = indices[next_T0];
3669 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3670 return false;
3671 }
3672 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3673 vop0 = _mm256_fmadd_ps(
3674 vwgt,
3675 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3676 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3677 _mm256_add_ps(vop0, vbio));
3678 _mm_prefetch(
3679 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3680 vop8 = _mm256_fmadd_ps(
3681 vwgt,
3682 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3683 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3684 _mm256_add_ps(vop8, vbio));
3685 // skip unnecessary prefetch of (&ip_next_T0[8])
3686 vop16 = _mm256_fmadd_ps(
3687 vwgt,
3688 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3689 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3690 _mm256_add_ps(vop16, vbio));
3691 // skip unnecessary prefetch of (&ip_next_T0[16])
3692 vop24 = _mm256_fmadd_ps(
3693 vwgt,
3694 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3695 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3696 _mm256_add_ps(vop24, vbio));
3697 // skip unnecessary prefetch of (&ip_next_T0[24])
3698 }
3699 if (!normalize_by_lengths || length == 0) {
3700 _mm256_storeu_ps(&op[0], vop0);
3701 _mm256_storeu_ps(&op[8], vop8);
3702 _mm256_storeu_ps(&op[16], vop16);
3703 _mm256_storeu_ps(&op[24], vop24);
3704 } else {
3705 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3706 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3707 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3708 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3709 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3710 }
3711 }
3712 } else if (block_size == 16) {
3713 // unrolling 2 times
3714 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3715 float* op = &out[rangeIndex * block_size];
3716 __m256 vop0 = _mm256_setzero_ps();
3717 __m256 vop8 = _mm256_setzero_ps();
3718 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3719 return false;
3720 }
3721 int64_t end_offset = offsets[rangeIndex + 1];
3722 int64_t length = end_offset - offsets[rangeIndex];
3723 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3724 ++dataInd) {
3725 const int idx = indices[dataInd];
3726 if (idx < 0 || idx >= data_size) {
3727 return false;
3728 }
3729 float wgt = 1.f;
3730 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3731 float bio;
3732 if (weights) {
3733 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3734 }
3735 bio = wgt * scale_bias[2 * idx + 1];
3736 wgt = wgt * scale_bias[2 * idx];
3737 __m256 vbio = _mm256_set1_ps(bio);
3738 __m256 vwgt = _mm256_set1_ps(wgt);
3739 const uint8_t* ip = &input[idx * fused_block_size];
3740 const int next_T0 = (dataInd < index_size - prefdist_T0)
3741 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3742 ? (dataInd + prefdist_T0)
3743 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3744 : dataInd;
3745 const int idx_pref_T0 = indices[next_T0];
3746 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3747 return false;
3748 }
3749 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3750 vop0 = _mm256_fmadd_ps(
3751 vwgt,
3752 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3753 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3754 _mm256_add_ps(vop0, vbio));
3755 _mm_prefetch(
3756 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3757 vop8 = _mm256_fmadd_ps(
3758 vwgt,
3759 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3760 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3761 _mm256_add_ps(vop8, vbio));
3762 // skip unnecessary prefetch of (&ip_next_T0[8])
3763 }
3764 if (!normalize_by_lengths || length == 0) {
3765 _mm256_storeu_ps(&op[0], vop0);
3766 _mm256_storeu_ps(&op[8], vop8);
3767 } else {
3768 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3769 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3770 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3771 }
3772 }
3773 } else {
3774 // generic code
3775 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
3776 for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3777 float* op = &out[rangeIndex * block_size];
3778 int64_t j = 0;
3779 for (; j + 8 <= block_size; j += 8) {
3780 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
3781 }
3782 for (; j < block_size; j++) {
3783 op[j] = 0.0f;
3784 }
3785 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3786 return false;
3787 }
3788 int64_t end_offset = offsets[rangeIndex + 1];
3789 int64_t length = end_offset - offsets[rangeIndex];
3790 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3791 ++dataInd) {
3792 const int idx = indices[dataInd];
3793 if (idx < 0 || idx >= data_size) {
3794 return false;
3795 }
3796 float wgt = 1.f;
3797 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3798 float bio;
3799 if (weights) {
3800 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3801 }
3802 bio = wgt * scale_bias[2 * idx + 1];
3803 wgt = wgt * scale_bias[2 * idx];
3804 __m256 vbio = _mm256_set1_ps(bio);
3805 __m256 vwgt = _mm256_set1_ps(wgt);
3806 const uint8_t* ip = &input[idx * fused_block_size];
3807 const int next_T0 = (dataInd < index_size - prefdist_T0)
3808 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3809 ? (dataInd + prefdist_T0)
3810 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3811 : dataInd;
3812 const int idx_pref_T0 = indices[next_T0];
3813 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3814 return false;
3815 }
3816 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3817 j = 0;
3818 for (; j + 8 <= block_size; j += 8) {
3819 _mm256_storeu_ps(
3820 &op[j],
3821 _mm256_fmadd_ps(
3822 vwgt,
3823 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
3824 reinterpret_cast<const __m128i*>(&ip[j])))),
3825 _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
3826 _mm_prefetch(
3827 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
3828 }
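        // remainder (block_size % 8): dequantize and accumulate one byte at
        // a time with a scalar fma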
3829 for (; j < block_size; j++) {
3830 op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);
3831 }
3832 }
3833 if (normalize_by_lengths && length) {
3834 float len_inv = 1.0f / length;
3835 __m256 vlen_inv = _mm256_set1_ps(len_inv);
3836 j = 0;
3837 for (; j + 8 <= block_size; j += 8) {
3838 _mm256_storeu_ps(
3839 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
3840 }
3841 for (; j < block_size; j++) {
3842 op[j] = len_inv * op[j];
3843 }
3844 }
3845 }
3846 }
3847 return dataInd == index_size;
3848 }
3849 bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
3850 const int64_t block_size,
3851 const int64_t output_size,
3852 const int64_t index_size,
3853 const int64_t data_size,
3854 const uint8_t* input,
3855 const int* indices,
3856 const int* offsets,
3857 const float* weights,
3858 const float* scale_bias,
3859 bool normalize_by_lengths,
3860 float* out) {
3861 return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<false>(
3862 block_size,
3863 output_size,
3864 index_size,
3865 data_size,
3866 input,
3867 indices,
3868 offsets,
3869 weights,
3870 scale_bias,
3871 normalize_by_lengths,
3872 out);
3873 }
3874 bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
3875 const int64_t block_size,
3876 const int64_t output_size,
3877 const int64_t index_size,
3878 const int64_t data_size,
3879 const uint8_t* input,
3880 const int* indices,
3881 const int* offsets,
3882 const float* weights,
3883 const float* scale_bias,
3884 bool normalize_by_lengths,
3885 float* out) {
3886 return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<true>(
3887 block_size,
3888 output_size,
3889 index_size,
3890 data_size,
3891 input,
3892 indices,
3893 offsets,
3894 weights,
3895 scale_bias,
3896 normalize_by_lengths,
3897 out);
3898 }
3899
3900 template <bool IS_WEIGHT_POSITIONAL>
3901 static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma(
3902 const int64_t block_size,
3903 const int64_t output_size,
3904 const int64_t index_size,
3905 const int64_t data_size,
3906 const uint8_t* input,
3907 const int64_t* indices,
3908 const int64_t* offsets,
3909 const float* weights,
3910 const float* scale_bias,
3911 bool normalize_by_lengths,
3912 float* out) {
3913 const int64_t prefdist_T0 = 16;
3914 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3915 const int64_t fused_block_size = block_size + 0;
3916 int64_t dataInd = 0;
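  // prefdist_T0: the row referenced 16 indices ahead is prefetched with
  // _MM_HINT_T0 to hide embedding-table cache misses; prefetches that fall
  // into an already-prefetched 64-byte cache line are skipped below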
3917 if (block_size == 128) {
3918 // unrolling 16 times
3919 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3920 float* op = &out[rangeIndex * block_size];
3921 __m256 vop0 = _mm256_setzero_ps();
3922 __m256 vop8 = _mm256_setzero_ps();
3923 __m256 vop16 = _mm256_setzero_ps();
3924 __m256 vop24 = _mm256_setzero_ps();
3925 __m256 vop32 = _mm256_setzero_ps();
3926 __m256 vop40 = _mm256_setzero_ps();
3927 __m256 vop48 = _mm256_setzero_ps();
3928 __m256 vop56 = _mm256_setzero_ps();
3929 __m256 vop64 = _mm256_setzero_ps();
3930 __m256 vop72 = _mm256_setzero_ps();
3931 __m256 vop80 = _mm256_setzero_ps();
3932 __m256 vop88 = _mm256_setzero_ps();
3933 __m256 vop96 = _mm256_setzero_ps();
3934 __m256 vop104 = _mm256_setzero_ps();
3935 __m256 vop112 = _mm256_setzero_ps();
3936 __m256 vop120 = _mm256_setzero_ps();
3937 if (dataInd != offsets[rangeIndex] - offsets[0]) {
3938 return false;
3939 }
3940 int64_t end_offset = offsets[rangeIndex + 1];
3941 int64_t length = end_offset - offsets[rangeIndex];
3942 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3943 ++dataInd) {
3944 const int64_t idx = indices[dataInd];
3945 if (idx < 0 || idx >= data_size) {
3946 return false;
3947 }
3948 float wgt = 1.f;
3949 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3950 float bio;
3951 if (weights) {
3952 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3953 }
3954 bio = wgt * scale_bias[2 * idx + 1];
3955 wgt = wgt * scale_bias[2 * idx];
3956 __m256 vbio = _mm256_set1_ps(bio);
3957 __m256 vwgt = _mm256_set1_ps(wgt);
3958 const uint8_t* ip = &input[idx * fused_block_size];
3959 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3960 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3961 ? (dataInd + prefdist_T0)
3962 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3963 : dataInd;
3964 const int64_t idx_pref_T0 = indices[next_T0];
3965 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3966 return false;
3967 }
3968 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3969 vop0 = _mm256_fmadd_ps(
3970 vwgt,
3971 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3972 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3973 _mm256_add_ps(vop0, vbio));
3974 _mm_prefetch(
3975 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3976 vop8 = _mm256_fmadd_ps(
3977 vwgt,
3978 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3979 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3980 _mm256_add_ps(vop8, vbio));
3981 // skip unnecessary prefetch of (&ip_next_T0[8])
3982 vop16 = _mm256_fmadd_ps(
3983 vwgt,
3984 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3985 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3986 _mm256_add_ps(vop16, vbio));
3987 // skip unnecessary prefetch of (&ip_next_T0[16])
3988 vop24 = _mm256_fmadd_ps(
3989 vwgt,
3990 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3991 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3992 _mm256_add_ps(vop24, vbio));
3993 // skip unnecessary prefetch of (&ip_next_T0[24])
3994 vop32 = _mm256_fmadd_ps(
3995 vwgt,
3996 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3997 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
3998 _mm256_add_ps(vop32, vbio));
3999 // skip unnecessary prefetch of (&ip_next_T0[32])
4000 vop40 = _mm256_fmadd_ps(
4001 vwgt,
4002 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4003 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
4004 _mm256_add_ps(vop40, vbio));
4005 // skip unnecessary prefetch of (&ip_next_T0[40])
4006 vop48 = _mm256_fmadd_ps(
4007 vwgt,
4008 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4009 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
4010 _mm256_add_ps(vop48, vbio));
4011 // skip unnecessary prefetch of (&ip_next_T0[48])
4012 vop56 = _mm256_fmadd_ps(
4013 vwgt,
4014 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4015 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
4016 _mm256_add_ps(vop56, vbio));
4017 // skip unnecessary prefetch of (&ip_next_T0[56])
4018 vop64 = _mm256_fmadd_ps(
4019 vwgt,
4020 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4021 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
4022 _mm256_add_ps(vop64, vbio));
4023 _mm_prefetch(
4024 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
4025 vop72 = _mm256_fmadd_ps(
4026 vwgt,
4027 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4028 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
4029 _mm256_add_ps(vop72, vbio));
4030 // skip unnecessary prefetch of (&ip_next_T0[72])
4031 vop80 = _mm256_fmadd_ps(
4032 vwgt,
4033 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4034 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
4035 _mm256_add_ps(vop80, vbio));
4036 // skip unnecessary prefetch of (&ip_next_T0[80])
4037 vop88 = _mm256_fmadd_ps(
4038 vwgt,
4039 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4040 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
4041 _mm256_add_ps(vop88, vbio));
4042 // skip unnecessary prefetch of (&ip_next_T0[88])
4043 vop96 = _mm256_fmadd_ps(
4044 vwgt,
4045 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4046 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
4047 _mm256_add_ps(vop96, vbio));
4048 // skip unnecessary prefetch of (&ip_next_T0[96])
4049 vop104 = _mm256_fmadd_ps(
4050 vwgt,
4051 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4052 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
4053 _mm256_add_ps(vop104, vbio));
4054 // skip unnecessary prefetch of (&ip_next_T0[104])
4055 vop112 = _mm256_fmadd_ps(
4056 vwgt,
4057 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4058 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
4059 _mm256_add_ps(vop112, vbio));
4060 // skip unnecessary prefetch of (&ip_next_T0[112])
4061 vop120 = _mm256_fmadd_ps(
4062 vwgt,
4063 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4064 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
4065 _mm256_add_ps(vop120, vbio));
4066 // skip unnecessary prefetch of (&ip_next_T0[120])
4067 }
4068 if (!normalize_by_lengths || length == 0) {
4069 _mm256_storeu_ps(&op[0], vop0);
4070 _mm256_storeu_ps(&op[8], vop8);
4071 _mm256_storeu_ps(&op[16], vop16);
4072 _mm256_storeu_ps(&op[24], vop24);
4073 _mm256_storeu_ps(&op[32], vop32);
4074 _mm256_storeu_ps(&op[40], vop40);
4075 _mm256_storeu_ps(&op[48], vop48);
4076 _mm256_storeu_ps(&op[56], vop56);
4077 _mm256_storeu_ps(&op[64], vop64);
4078 _mm256_storeu_ps(&op[72], vop72);
4079 _mm256_storeu_ps(&op[80], vop80);
4080 _mm256_storeu_ps(&op[88], vop88);
4081 _mm256_storeu_ps(&op[96], vop96);
4082 _mm256_storeu_ps(&op[104], vop104);
4083 _mm256_storeu_ps(&op[112], vop112);
4084 _mm256_storeu_ps(&op[120], vop120);
4085 } else {
4086 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
4087 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
4088 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
4089 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
4090 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
4091 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
4092 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
4093 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
4094 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
4095 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
4096 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
4097 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
4098 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
4099 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
4100 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
4101 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
4102 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
4103 }
4104 }
4105 } else if (block_size == 64) {
4106 // unrolling 8 times
4107 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
4108 float* op = &out[rangeIndex * block_size];
4109 __m256 vop0 = _mm256_setzero_ps();
4110 __m256 vop8 = _mm256_setzero_ps();
4111 __m256 vop16 = _mm256_setzero_ps();
4112 __m256 vop24 = _mm256_setzero_ps();
4113 __m256 vop32 = _mm256_setzero_ps();
4114 __m256 vop40 = _mm256_setzero_ps();
4115 __m256 vop48 = _mm256_setzero_ps();
4116 __m256 vop56 = _mm256_setzero_ps();
4117 if (dataInd != offsets[rangeIndex] - offsets[0]) {
4118 return false;
4119 }
4120 int64_t end_offset = offsets[rangeIndex + 1];
4121 int64_t length = end_offset - offsets[rangeIndex];
4122 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
4123 ++dataInd) {
4124 const int64_t idx = indices[dataInd];
4125 if (idx < 0 || idx >= data_size) {
4126 return false;
4127 }
4128 float wgt = 1.f;
4129 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
4130 float bio;
4131 if (weights) {
4132 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
4133 }
4134 bio = wgt * scale_bias[2 * idx + 1];
4135 wgt = wgt * scale_bias[2 * idx];
4136 __m256 vbio = _mm256_set1_ps(bio);
4137 __m256 vwgt = _mm256_set1_ps(wgt);
4138 const uint8_t* ip = &input[idx * fused_block_size];
4139 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
4140 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4141 ? (dataInd + prefdist_T0)
4142 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4143 : dataInd;
4144 const int64_t idx_pref_T0 = indices[next_T0];
4145 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
4146 return false;
4147 }
4148 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
4149 vop0 = _mm256_fmadd_ps(
4150 vwgt,
4151 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4152 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
4153 _mm256_add_ps(vop0, vbio));
4154 _mm_prefetch(
4155 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
4156 vop8 = _mm256_fmadd_ps(
4157 vwgt,
4158 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4159 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
4160 _mm256_add_ps(vop8, vbio));
4161 // skip unnecessary prefetch of (&ip_next_T0[8])
4162 vop16 = _mm256_fmadd_ps(
4163 vwgt,
4164 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4165 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
4166 _mm256_add_ps(vop16, vbio));
4167 // skip unnecessary prefetch of (&ip_next_T0[16])
4168 vop24 = _mm256_fmadd_ps(
4169 vwgt,
4170 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4171 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
4172 _mm256_add_ps(vop24, vbio));
4173 // skip unnecessary prefetch of (&ip_next_T0[24])
4174 vop32 = _mm256_fmadd_ps(
4175 vwgt,
4176 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4177 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
4178 _mm256_add_ps(vop32, vbio));
4179 // skip unnecessary prefetch of (&ip_next_T0[32])
4180 vop40 = _mm256_fmadd_ps(
4181 vwgt,
4182 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4183 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
4184 _mm256_add_ps(vop40, vbio));
4185 // skip unnecessary prefetch of (&ip_next_T0[40])
4186 vop48 = _mm256_fmadd_ps(
4187 vwgt,
4188 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4189 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
4190 _mm256_add_ps(vop48, vbio));
4191 // skip unnecessary prefetch of (&ip_next_T0[48])
4192 vop56 = _mm256_fmadd_ps(
4193 vwgt,
4194 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4195 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
4196 _mm256_add_ps(vop56, vbio));
4197 // skip unnecessary prefetch of (&ip_next_T0[56])
4198 }
4199 if (!normalize_by_lengths || length == 0) {
4200 _mm256_storeu_ps(&op[0], vop0);
4201 _mm256_storeu_ps(&op[8], vop8);
4202 _mm256_storeu_ps(&op[16], vop16);
4203 _mm256_storeu_ps(&op[24], vop24);
4204 _mm256_storeu_ps(&op[32], vop32);
4205 _mm256_storeu_ps(&op[40], vop40);
4206 _mm256_storeu_ps(&op[48], vop48);
4207 _mm256_storeu_ps(&op[56], vop56);
4208 } else {
4209 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
4210 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
4211 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
4212 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
4213 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
4214 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
4215 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
4216 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
4217 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
4218 }
4219 }
4220 } else if (block_size == 32) {
4221 // unrolling 4 times
4222 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
4223 float* op = &out[rangeIndex * block_size];
4224 __m256 vop0 = _mm256_setzero_ps();
4225 __m256 vop8 = _mm256_setzero_ps();
4226 __m256 vop16 = _mm256_setzero_ps();
4227 __m256 vop24 = _mm256_setzero_ps();
4228 if (dataInd != offsets[rangeIndex] - offsets[0]) {
4229 return false;
4230 }
4231 int64_t end_offset = offsets[rangeIndex + 1];
4232 int64_t length = end_offset - offsets[rangeIndex];
4233 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
4234 ++dataInd) {
4235 const int64_t idx = indices[dataInd];
4236 if (idx < 0 || idx >= data_size) {
4237 return false;
4238 }
4239 float wgt = 1.f;
4240 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
4241 float bio;
4242 if (weights) {
4243 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
4244 }
4245 bio = wgt * scale_bias[2 * idx + 1];
4246 wgt = wgt * scale_bias[2 * idx];
4247 __m256 vbio = _mm256_set1_ps(bio);
4248 __m256 vwgt = _mm256_set1_ps(wgt);
4249 const uint8_t* ip = &input[idx * fused_block_size];
4250 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
4251 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4252 ? (dataInd + prefdist_T0)
4253 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4254 : dataInd;
4255 const int64_t idx_pref_T0 = indices[next_T0];
4256 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
4257 return false;
4258 }
4259 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
4260 vop0 = _mm256_fmadd_ps(
4261 vwgt,
4262 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4263 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
4264 _mm256_add_ps(vop0, vbio));
4265 _mm_prefetch(
4266 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
4267 vop8 = _mm256_fmadd_ps(
4268 vwgt,
4269 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4270 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
4271 _mm256_add_ps(vop8, vbio));
4272 // skip unnecessary prefetch of (&ip_next_T0[8])
4273 vop16 = _mm256_fmadd_ps(
4274 vwgt,
4275 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4276 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
4277 _mm256_add_ps(vop16, vbio));
4278 // skip unnecessary prefetch of (&ip_next_T0[16])
4279 vop24 = _mm256_fmadd_ps(
4280 vwgt,
4281 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4282 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
4283 _mm256_add_ps(vop24, vbio));
4284 // skip unnecessary prefetch of (&ip_next_T0[24])
4285 }
4286 if (!normalize_by_lengths || length == 0) {
4287 _mm256_storeu_ps(&op[0], vop0);
4288 _mm256_storeu_ps(&op[8], vop8);
4289 _mm256_storeu_ps(&op[16], vop16);
4290 _mm256_storeu_ps(&op[24], vop24);
4291 } else {
4292 __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
4293 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
4294 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
4295 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
4296 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
4297 }
4298 }
4299 } else if (block_size == 16) {
4300 // unrolling 2 times
4301 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
4302 float* op = &out[rangeIndex * block_size];
4303 __m256 vop0 = _mm256_setzero_ps();
4304 __m256 vop8 = _mm256_setzero_ps();
4305 if (dataInd != offsets[rangeIndex] - offsets[0]) {
4306 return false;
4307 }
4308 int64_t end_offset = offsets[rangeIndex + 1];
4309 int64_t length = end_offset - offsets[rangeIndex];
4310 for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
4311 ++dataInd) {
4312 const int64_t idx = indices[dataInd];
4313 if (idx < 0 || idx >= data_size) {
4314 return false;
4315 }
4316 float wgt = 1.f;
4317 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
4318 float bio;
4319 if (weights) {
4320 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
4321 }
4322 bio = wgt * scale_bias[2 * idx + 1];
4323 wgt = wgt * scale_bias[2 * idx];
4324 __m256 vbio = _mm256_set1_ps(bio);
4325 __m256 vwgt = _mm256_set1_ps(wgt);
4326 const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
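      // Mean-pooling (normalize_by_lengths): divide the accumulated sums by
      // the bag length; empty bags keep their zero-initialized output.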
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
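    // Fallback for arbitrary block_size: zero the output row, then process it
    // in 8-float vector chunks with a scalar std::fma tail for the remainder.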
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);
        }
      }
      if (normalize_by_lengths && length) {
        float len_inv = 1.0f / length;
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
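// Non-templated entry points emitted for each IS_WEIGHT_POSITIONAL value;
// they simply forward to the templated kernel above.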
bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}
bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

} // namespace caffe2