1 //// --------------------------
2 //// ATTENTION:
3 //// THIS CODE IS AUTOGENERATED
4 //// BY hp_emblookup_codegen.py
5 //// DO NOT MODIFY!!!
6 //// --------------------------
7 
8 #include <c10/util/Half.h>
9 #include <c10/util/BFloat16.h>
10 #include <immintrin.h>
11 namespace caffe2 {
12 
13 template <bool IS_WEIGHT_POSITIONAL>
14 static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma(
15     const int64_t block_size,
16     const int64_t output_size,
17     const int64_t index_size,
18     const int64_t data_size,
19     const float* input,
20     const int* indices,
21     const int* offsets,
22     const float* weights,
23     const float* scale_bias,
24     bool normalize_by_lengths,
25     float* out) {
26   const int prefdist_T0 = 16;
27   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
28   const int fused_block_size = block_size + 0;
29   int64_t dataInd = 0;
30   if (block_size == 128) {
31     // unrolling 16 times
32     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
33       float* op = &out[rangeIndex * block_size];
34       __m256 vop0 = _mm256_setzero_ps();
35       __m256 vop8 = _mm256_setzero_ps();
36       __m256 vop16 = _mm256_setzero_ps();
37       __m256 vop24 = _mm256_setzero_ps();
38       __m256 vop32 = _mm256_setzero_ps();
39       __m256 vop40 = _mm256_setzero_ps();
40       __m256 vop48 = _mm256_setzero_ps();
41       __m256 vop56 = _mm256_setzero_ps();
42       __m256 vop64 = _mm256_setzero_ps();
43       __m256 vop72 = _mm256_setzero_ps();
44       __m256 vop80 = _mm256_setzero_ps();
45       __m256 vop88 = _mm256_setzero_ps();
46       __m256 vop96 = _mm256_setzero_ps();
47       __m256 vop104 = _mm256_setzero_ps();
48       __m256 vop112 = _mm256_setzero_ps();
49       __m256 vop120 = _mm256_setzero_ps();
50       if (dataInd != offsets[rangeIndex] - offsets[0]) {
51         return false;
52       }
53       int64_t end_offset = offsets[rangeIndex + 1];
54       int64_t length = end_offset - offsets[rangeIndex];
55       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
56            ++dataInd) {
57         const int idx = indices[dataInd];
58         if (idx < 0 || idx >= data_size) {
59           return false;
60         }
61         float wgt = 1.f;
62         if (weights) {
63           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
64         }
65         __m256 vwgt = _mm256_set1_ps(wgt);
66         const float* ip = &input[idx * fused_block_size];
67         const int next_T0 = (dataInd < index_size - prefdist_T0)
68             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
69             ? (dataInd + prefdist_T0)
70             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
71             : dataInd;
72         const int idx_pref_T0 = indices[next_T0];
73         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
74           return false;
75         }
76         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
77         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
78         _mm_prefetch(
79             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
80         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
81         // skip unnecessary prefetch of (&ip_next_T0[8])
82         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
83         _mm_prefetch(
84             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
85         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
86         // skip unnecessary prefetch of (&ip_next_T0[24])
87         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
88         _mm_prefetch(
89             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
90         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
91         // skip unnecessary prefetch of (&ip_next_T0[40])
92         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
93         _mm_prefetch(
94             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
95         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
96         // skip unnecessary prefetch of (&ip_next_T0[56])
97         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
98         _mm_prefetch(
99             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
100         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
101         // skip unnecessary prefetch of (&ip_next_T0[72])
102         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
103         _mm_prefetch(
104             reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
105         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
106         // skip unnecessary prefetch of (&ip_next_T0[88])
107         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
108         _mm_prefetch(
109             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
110         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
111         // skip unnecessary prefetch of (&ip_next_T0[104])
112         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
113         _mm_prefetch(
114             reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
115         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
116         // skip unnecessary prefetch of (&ip_next_T0[120])
117       }
118       if (!normalize_by_lengths || length == 0) {
119         _mm256_storeu_ps(&op[0], vop0);
120         _mm256_storeu_ps(&op[8], vop8);
121         _mm256_storeu_ps(&op[16], vop16);
122         _mm256_storeu_ps(&op[24], vop24);
123         _mm256_storeu_ps(&op[32], vop32);
124         _mm256_storeu_ps(&op[40], vop40);
125         _mm256_storeu_ps(&op[48], vop48);
126         _mm256_storeu_ps(&op[56], vop56);
127         _mm256_storeu_ps(&op[64], vop64);
128         _mm256_storeu_ps(&op[72], vop72);
129         _mm256_storeu_ps(&op[80], vop80);
130         _mm256_storeu_ps(&op[88], vop88);
131         _mm256_storeu_ps(&op[96], vop96);
132         _mm256_storeu_ps(&op[104], vop104);
133         _mm256_storeu_ps(&op[112], vop112);
134         _mm256_storeu_ps(&op[120], vop120);
135       } else {
136         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
137         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
138         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
139         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
140         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
141         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
142         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
143         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
144         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
145         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
146         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
147         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
148         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
149         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
150         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
151         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
152         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
153       }
154     }
155   } else if (block_size == 64) {
156     // unrolling 8 times
157     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
158       float* op = &out[rangeIndex * block_size];
159       __m256 vop0 = _mm256_setzero_ps();
160       __m256 vop8 = _mm256_setzero_ps();
161       __m256 vop16 = _mm256_setzero_ps();
162       __m256 vop24 = _mm256_setzero_ps();
163       __m256 vop32 = _mm256_setzero_ps();
164       __m256 vop40 = _mm256_setzero_ps();
165       __m256 vop48 = _mm256_setzero_ps();
166       __m256 vop56 = _mm256_setzero_ps();
167       if (dataInd != offsets[rangeIndex] - offsets[0]) {
168         return false;
169       }
170       int64_t end_offset = offsets[rangeIndex + 1];
171       int64_t length = end_offset - offsets[rangeIndex];
172       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
173            ++dataInd) {
174         const int idx = indices[dataInd];
175         if (idx < 0 || idx >= data_size) {
176           return false;
177         }
178         float wgt = 1.f;
179         if (weights) {
180           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
181         }
182         __m256 vwgt = _mm256_set1_ps(wgt);
183         const float* ip = &input[idx * fused_block_size];
184         const int next_T0 = (dataInd < index_size - prefdist_T0)
185             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
186             ? (dataInd + prefdist_T0)
187             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
188             : dataInd;
189         const int idx_pref_T0 = indices[next_T0];
190         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
191           return false;
192         }
193         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
194         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
195         _mm_prefetch(
196             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
197         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
198         // skip unnecessary prefetch of (&ip_next_T0[8])
199         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
200         _mm_prefetch(
201             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
202         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
203         // skip unnecessary prefetch of (&ip_next_T0[24])
204         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
205         _mm_prefetch(
206             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
207         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
208         // skip unnecessary prefetch of (&ip_next_T0[40])
209         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
210         _mm_prefetch(
211             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
212         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
213         // skip unnecessary prefetch of (&ip_next_T0[56])
214       }
215       if (!normalize_by_lengths || length == 0) {
216         _mm256_storeu_ps(&op[0], vop0);
217         _mm256_storeu_ps(&op[8], vop8);
218         _mm256_storeu_ps(&op[16], vop16);
219         _mm256_storeu_ps(&op[24], vop24);
220         _mm256_storeu_ps(&op[32], vop32);
221         _mm256_storeu_ps(&op[40], vop40);
222         _mm256_storeu_ps(&op[48], vop48);
223         _mm256_storeu_ps(&op[56], vop56);
224       } else {
225         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
226         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
227         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
228         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
229         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
230         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
231         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
232         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
233         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
234       }
235     }
236   } else if (block_size == 32) {
237     // unrolling 4 times
238     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
239       float* op = &out[rangeIndex * block_size];
240       __m256 vop0 = _mm256_setzero_ps();
241       __m256 vop8 = _mm256_setzero_ps();
242       __m256 vop16 = _mm256_setzero_ps();
243       __m256 vop24 = _mm256_setzero_ps();
244       if (dataInd != offsets[rangeIndex] - offsets[0]) {
245         return false;
246       }
247       int64_t end_offset = offsets[rangeIndex + 1];
248       int64_t length = end_offset - offsets[rangeIndex];
249       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
250            ++dataInd) {
251         const int idx = indices[dataInd];
252         if (idx < 0 || idx >= data_size) {
253           return false;
254         }
255         float wgt = 1.f;
256         if (weights) {
257           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
258         }
259         __m256 vwgt = _mm256_set1_ps(wgt);
260         const float* ip = &input[idx * fused_block_size];
261         const int next_T0 = (dataInd < index_size - prefdist_T0)
262             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
263             ? (dataInd + prefdist_T0)
264             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
265             : dataInd;
266         const int idx_pref_T0 = indices[next_T0];
267         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
268           return false;
269         }
270         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
271         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
272         _mm_prefetch(
273             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
274         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
275         // skip unnecessary prefetch of (&ip_next_T0[8])
276         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
277         _mm_prefetch(
278             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
279         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
280         // skip unnecessary prefetch of (&ip_next_T0[24])
281       }
282       if (!normalize_by_lengths || length == 0) {
283         _mm256_storeu_ps(&op[0], vop0);
284         _mm256_storeu_ps(&op[8], vop8);
285         _mm256_storeu_ps(&op[16], vop16);
286         _mm256_storeu_ps(&op[24], vop24);
287       } else {
288         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
289         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
290         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
291         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
292         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
293       }
294     }
295   } else if (block_size == 16) {
296     // unrolling 2 times
297     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
298       float* op = &out[rangeIndex * block_size];
299       __m256 vop0 = _mm256_setzero_ps();
300       __m256 vop8 = _mm256_setzero_ps();
301       if (dataInd != offsets[rangeIndex] - offsets[0]) {
302         return false;
303       }
304       int64_t end_offset = offsets[rangeIndex + 1];
305       int64_t length = end_offset - offsets[rangeIndex];
306       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
307            ++dataInd) {
308         const int idx = indices[dataInd];
309         if (idx < 0 || idx >= data_size) {
310           return false;
311         }
312         float wgt = 1.f;
313         if (weights) {
314           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
315         }
316         __m256 vwgt = _mm256_set1_ps(wgt);
317         const float* ip = &input[idx * fused_block_size];
318         const int next_T0 = (dataInd < index_size - prefdist_T0)
319             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
320             ? (dataInd + prefdist_T0)
321             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
322             : dataInd;
323         const int idx_pref_T0 = indices[next_T0];
324         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
325           return false;
326         }
327         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
328         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
329         _mm_prefetch(
330             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
331         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
332         // skip unnecessary prefetch of (&ip_next_T0[8])
333       }
334       if (!normalize_by_lengths || length == 0) {
335         _mm256_storeu_ps(&op[0], vop0);
336         _mm256_storeu_ps(&op[8], vop8);
337       } else {
338         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
339         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
340         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
341       }
342     }
343   } else {
344     // generic code
345     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
346     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
347       float* op = &out[rangeIndex * block_size];
348       int64_t j = 0;
349       for (; j + 8 <= block_size; j += 8) {
350         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
351       }
352       for (; j < block_size; j++) {
353         op[j] = 0.0f;
354       }
355       if (dataInd != offsets[rangeIndex] - offsets[0]) {
356         return false;
357       }
358       int64_t end_offset = offsets[rangeIndex + 1];
359       int64_t length = end_offset - offsets[rangeIndex];
360       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
361            ++dataInd) {
362         const int idx = indices[dataInd];
363         if (idx < 0 || idx >= data_size) {
364           return false;
365         }
366         float wgt = 1.f;
367         if (weights) {
368           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
369         }
370         __m256 vwgt = _mm256_set1_ps(wgt);
371         const float* ip = &input[idx * fused_block_size];
372         const int next_T0 = (dataInd < index_size - prefdist_T0)
373             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
374             ? (dataInd + prefdist_T0)
375             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
376             : dataInd;
377         const int idx_pref_T0 = indices[next_T0];
378         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
379           return false;
380         }
381         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
382         j = 0;
383         for (; j + 8 <= block_size; j += 8) {
384           _mm256_storeu_ps(
385               &op[j],
386               _mm256_fmadd_ps(
387                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
388           _mm_prefetch(
389               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
390         }
391         for (; j < block_size; j++) {
392           op[j] = std::fma(wgt, ip[j], op[j]);
393         }
394       }
395       if (normalize_by_lengths && length) {
396         float len_inv = 1.0f / length;
397         __m256 vlen_inv = _mm256_set1_ps(len_inv);
398         j = 0;
399         for (; j + 8 <= block_size; j += 8) {
400           _mm256_storeu_ps(
401               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
402         }
403         for (; j < block_size; j++) {
404           op[j] = len_inv * op[j];
405         }
406       }
407     }
408   }
409   return dataInd == index_size;
410 }
411 bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
412     const int64_t block_size,
413     const int64_t output_size,
414     const int64_t index_size,
415     const int64_t data_size,
416     const float* input,
417     const int* indices,
418     const int* offsets,
419     const float* weights,
420     const float* scale_bias,
421     bool normalize_by_lengths,
422     float* out) {
423   return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<false>(
424       block_size,
425       output_size,
426       index_size,
427       data_size,
428       input,
429       indices,
430       offsets,
431       weights,
432       scale_bias,
433       normalize_by_lengths,
434       out);
435 }
436 bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
437     const int64_t block_size,
438     const int64_t output_size,
439     const int64_t index_size,
440     const int64_t data_size,
441     const float* input,
442     const int* indices,
443     const int* offsets,
444     const float* weights,
445     const float* scale_bias,
446     bool normalize_by_lengths,
447     float* out) {
448   return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<true>(
449       block_size,
450       output_size,
451       index_size,
452       data_size,
453       input,
454       indices,
455       offsets,
456       weights,
457       scale_bias,
458       normalize_by_lengths,
459       out);
460 }
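
// Illustrative usage sketch (an assumption added for clarity, not part of the
// generated kernel): the caller supplies a CSR-style `offsets` array with
// `output_size + 1` entries, where `offsets[i + 1] - offsets[i]` is the number
// of indices pooled into output row `i`, and
// `offsets[output_size] - offsets[0] == index_size` must hold. For the float
// input type, `scale_bias` is unused and may be null. Hypothetical values:
//
//   float table[4 * 16];   // hypothetical 4-row, 16-wide embedding table
//   float out[2 * 16];     // two pooled output rows
//   const int indices[] = {0, 2, 1};   // 3 lookups in total
//   const int offsets[] = {0, 2, 3};   // row 0 pools {0, 2}, row 1 pools {1}
//   bool ok = EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
//       /*block_size=*/16, /*output_size=*/2, /*index_size=*/3, /*data_size=*/4,
//       table, indices, offsets,
//       /*weights=*/nullptr, /*scale_bias=*/nullptr,
//       /*normalize_by_lengths=*/false, out);
//   // ok is false if any index is out of range or the offsets are inconsistent.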
461 
462 template <bool IS_WEIGHT_POSITIONAL>
463 static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma(
464     const int64_t block_size,
465     const int64_t output_size,
466     const int64_t index_size,
467     const int64_t data_size,
468     const float* input,
469     const int64_t* indices,
470     const int64_t* offsets,
471     const float* weights,
472     const float* scale_bias,
473     bool normalize_by_lengths,
474     float* out) {
475   const int64_t prefdist_T0 = 16;
476   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
477   const int64_t fused_block_size = block_size + 0;
478   int64_t dataInd = 0;
479   if (block_size == 128) {
480     // unrolling 16 times
481     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
482       float* op = &out[rangeIndex * block_size];
483       __m256 vop0 = _mm256_setzero_ps();
484       __m256 vop8 = _mm256_setzero_ps();
485       __m256 vop16 = _mm256_setzero_ps();
486       __m256 vop24 = _mm256_setzero_ps();
487       __m256 vop32 = _mm256_setzero_ps();
488       __m256 vop40 = _mm256_setzero_ps();
489       __m256 vop48 = _mm256_setzero_ps();
490       __m256 vop56 = _mm256_setzero_ps();
491       __m256 vop64 = _mm256_setzero_ps();
492       __m256 vop72 = _mm256_setzero_ps();
493       __m256 vop80 = _mm256_setzero_ps();
494       __m256 vop88 = _mm256_setzero_ps();
495       __m256 vop96 = _mm256_setzero_ps();
496       __m256 vop104 = _mm256_setzero_ps();
497       __m256 vop112 = _mm256_setzero_ps();
498       __m256 vop120 = _mm256_setzero_ps();
499       if (dataInd != offsets[rangeIndex] - offsets[0]) {
500         return false;
501       }
502       int64_t end_offset = offsets[rangeIndex + 1];
503       int64_t length = end_offset - offsets[rangeIndex];
504       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
505            ++dataInd) {
506         const int64_t idx = indices[dataInd];
507         if (idx < 0 || idx >= data_size) {
508           return false;
509         }
510         float wgt = 1.f;
511         if (weights) {
512           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
513         }
514         __m256 vwgt = _mm256_set1_ps(wgt);
515         const float* ip = &input[idx * fused_block_size];
516         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
517             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
518             ? (dataInd + prefdist_T0)
519             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
520             : dataInd;
521         const int64_t idx_pref_T0 = indices[next_T0];
522         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
523           return false;
524         }
525         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
526         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
527         _mm_prefetch(
528             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
529         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
530         // skip unnecessary prefetch of (&ip_next_T0[8])
531         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
532         _mm_prefetch(
533             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
534         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
535         // skip unnecessary prefetch of (&ip_next_T0[24])
536         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
537         _mm_prefetch(
538             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
539         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
540         // skip unnecessary prefetch of (&ip_next_T0[40])
541         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
542         _mm_prefetch(
543             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
544         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
545         // skip unnecessary prefetch of (&ip_next_T0[56])
546         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
547         _mm_prefetch(
548             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
549         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
550         // skip unnecessary prefetch of (&ip_next_T0[72])
551         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
552         _mm_prefetch(
553             reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
554         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
555         // skip unnecessary prefetch of (&ip_next_T0[88])
556         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
557         _mm_prefetch(
558             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
559         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
560         // skip unnecessary prefetch of (&ip_next_T0[104])
561         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
562         _mm_prefetch(
563             reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
564         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
565         // skip unnecessary prefetch of (&ip_next_T0[120])
566       }
567       if (!normalize_by_lengths || length == 0) {
568         _mm256_storeu_ps(&op[0], vop0);
569         _mm256_storeu_ps(&op[8], vop8);
570         _mm256_storeu_ps(&op[16], vop16);
571         _mm256_storeu_ps(&op[24], vop24);
572         _mm256_storeu_ps(&op[32], vop32);
573         _mm256_storeu_ps(&op[40], vop40);
574         _mm256_storeu_ps(&op[48], vop48);
575         _mm256_storeu_ps(&op[56], vop56);
576         _mm256_storeu_ps(&op[64], vop64);
577         _mm256_storeu_ps(&op[72], vop72);
578         _mm256_storeu_ps(&op[80], vop80);
579         _mm256_storeu_ps(&op[88], vop88);
580         _mm256_storeu_ps(&op[96], vop96);
581         _mm256_storeu_ps(&op[104], vop104);
582         _mm256_storeu_ps(&op[112], vop112);
583         _mm256_storeu_ps(&op[120], vop120);
584       } else {
585         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
586         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
587         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
588         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
589         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
590         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
591         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
592         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
593         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
594         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
595         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
596         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
597         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
598         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
599         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
600         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
601         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
602       }
603     }
604   } else if (block_size == 64) {
605     // unrolling 8 times
606     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
607       float* op = &out[rangeIndex * block_size];
608       __m256 vop0 = _mm256_setzero_ps();
609       __m256 vop8 = _mm256_setzero_ps();
610       __m256 vop16 = _mm256_setzero_ps();
611       __m256 vop24 = _mm256_setzero_ps();
612       __m256 vop32 = _mm256_setzero_ps();
613       __m256 vop40 = _mm256_setzero_ps();
614       __m256 vop48 = _mm256_setzero_ps();
615       __m256 vop56 = _mm256_setzero_ps();
616       if (dataInd != offsets[rangeIndex] - offsets[0]) {
617         return false;
618       }
619       int64_t end_offset = offsets[rangeIndex + 1];
620       int64_t length = end_offset - offsets[rangeIndex];
621       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
622            ++dataInd) {
623         const int64_t idx = indices[dataInd];
624         if (idx < 0 || idx >= data_size) {
625           return false;
626         }
627         float wgt = 1.f;
628         if (weights) {
629           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
630         }
631         __m256 vwgt = _mm256_set1_ps(wgt);
632         const float* ip = &input[idx * fused_block_size];
633         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
634             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
635             ? (dataInd + prefdist_T0)
636             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
637             : dataInd;
638         const int64_t idx_pref_T0 = indices[next_T0];
639         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
640           return false;
641         }
642         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
643         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
644         _mm_prefetch(
645             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
646         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
647         // skip unnecessary prefetch of (&ip_next_T0[8])
648         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
649         _mm_prefetch(
650             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
651         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
652         // skip unnecessary prefetch of (&ip_next_T0[24])
653         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
654         _mm_prefetch(
655             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
656         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
657         // skip unnecessary prefetch of (&ip_next_T0[40])
658         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
659         _mm_prefetch(
660             reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
661         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
662         // skip unnecessary prefetch of (&ip_next_T0[56])
663       }
664       if (!normalize_by_lengths || length == 0) {
665         _mm256_storeu_ps(&op[0], vop0);
666         _mm256_storeu_ps(&op[8], vop8);
667         _mm256_storeu_ps(&op[16], vop16);
668         _mm256_storeu_ps(&op[24], vop24);
669         _mm256_storeu_ps(&op[32], vop32);
670         _mm256_storeu_ps(&op[40], vop40);
671         _mm256_storeu_ps(&op[48], vop48);
672         _mm256_storeu_ps(&op[56], vop56);
673       } else {
674         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
675         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
676         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
677         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
678         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
679         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
680         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
681         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
682         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
683       }
684     }
685   } else if (block_size == 32) {
686     // unrolling 4 times
687     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
688       float* op = &out[rangeIndex * block_size];
689       __m256 vop0 = _mm256_setzero_ps();
690       __m256 vop8 = _mm256_setzero_ps();
691       __m256 vop16 = _mm256_setzero_ps();
692       __m256 vop24 = _mm256_setzero_ps();
693       if (dataInd != offsets[rangeIndex] - offsets[0]) {
694         return false;
695       }
696       int64_t end_offset = offsets[rangeIndex + 1];
697       int64_t length = end_offset - offsets[rangeIndex];
698       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
699            ++dataInd) {
700         const int64_t idx = indices[dataInd];
701         if (idx < 0 || idx >= data_size) {
702           return false;
703         }
704         float wgt = 1.f;
705         if (weights) {
706           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
707         }
708         __m256 vwgt = _mm256_set1_ps(wgt);
709         const float* ip = &input[idx * fused_block_size];
710         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
711             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
712             ? (dataInd + prefdist_T0)
713             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
714             : dataInd;
715         const int64_t idx_pref_T0 = indices[next_T0];
716         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
717           return false;
718         }
719         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
720         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
721         _mm_prefetch(
722             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
723         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
724         // skip unnecessary prefetch of (&ip_next_T0[8])
725         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
726         _mm_prefetch(
727             reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
728         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
729         // skip unnecessary prefetch of (&ip_next_T0[24])
730       }
731       if (!normalize_by_lengths || length == 0) {
732         _mm256_storeu_ps(&op[0], vop0);
733         _mm256_storeu_ps(&op[8], vop8);
734         _mm256_storeu_ps(&op[16], vop16);
735         _mm256_storeu_ps(&op[24], vop24);
736       } else {
737         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
738         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
739         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
740         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
741         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
742       }
743     }
744   } else if (block_size == 16) {
745     // unrolling 2 times
746     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
747       float* op = &out[rangeIndex * block_size];
748       __m256 vop0 = _mm256_setzero_ps();
749       __m256 vop8 = _mm256_setzero_ps();
750       if (dataInd != offsets[rangeIndex] - offsets[0]) {
751         return false;
752       }
753       int64_t end_offset = offsets[rangeIndex + 1];
754       int64_t length = end_offset - offsets[rangeIndex];
755       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
756            ++dataInd) {
757         const int64_t idx = indices[dataInd];
758         if (idx < 0 || idx >= data_size) {
759           return false;
760         }
761         float wgt = 1.f;
762         if (weights) {
763           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
764         }
765         __m256 vwgt = _mm256_set1_ps(wgt);
766         const float* ip = &input[idx * fused_block_size];
767         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
768             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
769             ? (dataInd + prefdist_T0)
770             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
771             : dataInd;
772         const int64_t idx_pref_T0 = indices[next_T0];
773         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
774           return false;
775         }
776         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
777         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
778         _mm_prefetch(
779             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
780         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
781         // skip unnecessary prefetch of (&ip_next_T0[8])
782       }
783       if (!normalize_by_lengths || length == 0) {
784         _mm256_storeu_ps(&op[0], vop0);
785         _mm256_storeu_ps(&op[8], vop8);
786       } else {
787         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
788         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
789         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
790       }
791     }
792   } else {
793     // generic code
794     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
795     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
796       float* op = &out[rangeIndex * block_size];
797       int64_t j = 0;
798       for (; j + 8 <= block_size; j += 8) {
799         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
800       }
801       for (; j < block_size; j++) {
802         op[j] = 0.0f;
803       }
804       if (dataInd != offsets[rangeIndex] - offsets[0]) {
805         return false;
806       }
807       int64_t end_offset = offsets[rangeIndex + 1];
808       int64_t length = end_offset - offsets[rangeIndex];
809       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
810            ++dataInd) {
811         const int64_t idx = indices[dataInd];
812         if (idx < 0 || idx >= data_size) {
813           return false;
814         }
815         float wgt = 1.f;
816         if (weights) {
817           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
818         }
819         __m256 vwgt = _mm256_set1_ps(wgt);
820         const float* ip = &input[idx * fused_block_size];
821         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
822             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
823             ? (dataInd + prefdist_T0)
824             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
825             : dataInd;
826         const int64_t idx_pref_T0 = indices[next_T0];
827         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
828           return false;
829         }
830         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
831         j = 0;
832         for (; j + 8 <= block_size; j += 8) {
833           _mm256_storeu_ps(
834               &op[j],
835               _mm256_fmadd_ps(
836                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
837           _mm_prefetch(
838               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
839         }
840         for (; j < block_size; j++) {
841           op[j] = std::fma(wgt, ip[j], op[j]);
842         }
843       }
844       if (normalize_by_lengths && length) {
845         float len_inv = 1.0f / length;
846         __m256 vlen_inv = _mm256_set1_ps(len_inv);
847         j = 0;
848         for (; j + 8 <= block_size; j += 8) {
849           _mm256_storeu_ps(
850               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
851         }
852         for (; j < block_size; j++) {
853           op[j] = len_inv * op[j];
854         }
855       }
856     }
857   }
858   return dataInd == index_size;
859 }
860 bool EmbeddingLookupIdx_int64_t_float_float_false__avx2_fma(
861     const int64_t block_size,
862     const int64_t output_size,
863     const int64_t index_size,
864     const int64_t data_size,
865     const float* input,
866     const int64_t* indices,
867     const int64_t* offsets,
868     const float* weights,
869     const float* scale_bias,
870     bool normalize_by_lengths,
871     float* out) {
872   return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<false>(
873       block_size,
874       output_size,
875       index_size,
876       data_size,
877       input,
878       indices,
879       offsets,
880       weights,
881       scale_bias,
882       normalize_by_lengths,
883       out);
884 }
885 bool EmbeddingLookupIdx_int64_t_float_float_true__avx2_fma(
886     const int64_t block_size,
887     const int64_t output_size,
888     const int64_t index_size,
889     const int64_t data_size,
890     const float* input,
891     const int64_t* indices,
892     const int64_t* offsets,
893     const float* weights,
894     const float* scale_bias,
895     bool normalize_by_lengths,
896     float* out) {
897   return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<true>(
898       block_size,
899       output_size,
900       index_size,
901       data_size,
902       input,
903       indices,
904       offsets,
905       weights,
906       scale_bias,
907       normalize_by_lengths,
908       out);
909 }
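
// Illustrative sketch (an assumption added for clarity, not generated code):
// the `_true_` wrappers instantiate IS_WEIGHT_POSITIONAL, so `weights` is
// indexed by an index's position within its bag (`dataInd - start`) rather
// than by its global position in `indices`. Hypothetical values:
//
//   const int64_t indices[] = {5, 7, 9, 2};   // bag 0 = {5, 7, 9}, bag 1 = {2}
//   const int64_t offsets[] = {0, 3, 4};
//   const float weights[] = {0.5f, 0.25f, 0.25f};  // reused per position in each bag
//   float table[16 * 32];                     // hypothetical 16 x 32 table
//   float out[2 * 32];
//   bool ok = EmbeddingLookupIdx_int64_t_float_float_true__avx2_fma(
//       /*block_size=*/32, /*output_size=*/2, /*index_size=*/4, /*data_size=*/16,
//       table, indices, offsets, weights, /*scale_bias=*/nullptr,
//       /*normalize_by_lengths=*/false, out);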
910 
911 template <bool IS_WEIGHT_POSITIONAL>
912 static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma(
913     const int64_t block_size,
914     const int64_t output_size,
915     const int64_t index_size,
916     const int64_t data_size,
917     const at::Half* input,
918     const int* indices,
919     const int* offsets,
920     const float* weights,
921     const float* scale_bias,
922     bool normalize_by_lengths,
923     float* out) {
924   const int prefdist_T0 = 16;
925   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
926   const int fused_block_size = block_size + 0;
927   int64_t dataInd = 0;
928   if (block_size == 128) {
929     // unrolling 16 times
930     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
931       float* op = &out[rangeIndex * block_size];
932       __m256 vop0 = _mm256_setzero_ps();
933       __m256 vop8 = _mm256_setzero_ps();
934       __m256 vop16 = _mm256_setzero_ps();
935       __m256 vop24 = _mm256_setzero_ps();
936       __m256 vop32 = _mm256_setzero_ps();
937       __m256 vop40 = _mm256_setzero_ps();
938       __m256 vop48 = _mm256_setzero_ps();
939       __m256 vop56 = _mm256_setzero_ps();
940       __m256 vop64 = _mm256_setzero_ps();
941       __m256 vop72 = _mm256_setzero_ps();
942       __m256 vop80 = _mm256_setzero_ps();
943       __m256 vop88 = _mm256_setzero_ps();
944       __m256 vop96 = _mm256_setzero_ps();
945       __m256 vop104 = _mm256_setzero_ps();
946       __m256 vop112 = _mm256_setzero_ps();
947       __m256 vop120 = _mm256_setzero_ps();
948       if (dataInd != offsets[rangeIndex] - offsets[0]) {
949         return false;
950       }
951       int64_t end_offset = offsets[rangeIndex + 1];
952       int64_t length = end_offset - offsets[rangeIndex];
953       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
954            ++dataInd) {
955         const int idx = indices[dataInd];
956         if (idx < 0 || idx >= data_size) {
957           return false;
958         }
959         float wgt = 1.f;
960         if (weights) {
961           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
962         }
963         __m256 vwgt = _mm256_set1_ps(wgt);
964         const at::Half* ip = &input[idx * fused_block_size];
965         const int next_T0 = (dataInd < index_size - prefdist_T0)
966             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
967             ? (dataInd + prefdist_T0)
968             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
969             : dataInd;
970         const int idx_pref_T0 = indices[next_T0];
971         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
972           return false;
973         }
974         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
975         vop0 = _mm256_fmadd_ps(
976             vwgt,
977             _mm256_cvtph_ps(
978                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
979             vop0);
980         _mm_prefetch(
981             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
982         vop8 = _mm256_fmadd_ps(
983             vwgt,
984             _mm256_cvtph_ps(
985                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
986             vop8);
987         // skip unnecessary prefetch of (&ip_next_T0[8])
988         vop16 = _mm256_fmadd_ps(
989             vwgt,
990             _mm256_cvtph_ps(
991                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
992             vop16);
993         // skip unnecessary prefetch of (&ip_next_T0[16])
994         vop24 = _mm256_fmadd_ps(
995             vwgt,
996             _mm256_cvtph_ps(
997                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
998             vop24);
999         // skip unnecessary prefetch of (&ip_next_T0[24])
1000         vop32 = _mm256_fmadd_ps(
1001             vwgt,
1002             _mm256_cvtph_ps(
1003                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1004             vop32);
1005         _mm_prefetch(
1006             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1007         vop40 = _mm256_fmadd_ps(
1008             vwgt,
1009             _mm256_cvtph_ps(
1010                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1011             vop40);
1012         // skip unnecessary prefetch of (&ip_next_T0[40])
1013         vop48 = _mm256_fmadd_ps(
1014             vwgt,
1015             _mm256_cvtph_ps(
1016                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1017             vop48);
1018         // skip unnecessary prefetch of (&ip_next_T0[48])
1019         vop56 = _mm256_fmadd_ps(
1020             vwgt,
1021             _mm256_cvtph_ps(
1022                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1023             vop56);
1024         // skip unnecessary prefetch of (&ip_next_T0[56])
1025         vop64 = _mm256_fmadd_ps(
1026             vwgt,
1027             _mm256_cvtph_ps(
1028                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1029             vop64);
1030         _mm_prefetch(
1031             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1032         vop72 = _mm256_fmadd_ps(
1033             vwgt,
1034             _mm256_cvtph_ps(
1035                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1036             vop72);
1037         // skip unnecessary prefetch of (&ip_next_T0[72])
1038         vop80 = _mm256_fmadd_ps(
1039             vwgt,
1040             _mm256_cvtph_ps(
1041                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1042             vop80);
1043         // skip unnecessary prefetch of (&ip_next_T0[80])
1044         vop88 = _mm256_fmadd_ps(
1045             vwgt,
1046             _mm256_cvtph_ps(
1047                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1048             vop88);
1049         // skip unnecessary prefetch of (&ip_next_T0[88])
1050         vop96 = _mm256_fmadd_ps(
1051             vwgt,
1052             _mm256_cvtph_ps(
1053                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1054             vop96);
1055         _mm_prefetch(
1056             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1057         vop104 = _mm256_fmadd_ps(
1058             vwgt,
1059             _mm256_cvtph_ps(
1060                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1061             vop104);
1062         // skip unnecessary prefetch of (&ip_next_T0[104])
1063         vop112 = _mm256_fmadd_ps(
1064             vwgt,
1065             _mm256_cvtph_ps(
1066                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1067             vop112);
1068         // skip unnecessary prefetch of (&ip_next_T0[112])
1069         vop120 = _mm256_fmadd_ps(
1070             vwgt,
1071             _mm256_cvtph_ps(
1072                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1073             vop120);
1074         // skip unnecessary prefetch of (&ip_next_T0[120])
1075       }
1076       if (!normalize_by_lengths || length == 0) {
1077         _mm256_storeu_ps(&op[0], vop0);
1078         _mm256_storeu_ps(&op[8], vop8);
1079         _mm256_storeu_ps(&op[16], vop16);
1080         _mm256_storeu_ps(&op[24], vop24);
1081         _mm256_storeu_ps(&op[32], vop32);
1082         _mm256_storeu_ps(&op[40], vop40);
1083         _mm256_storeu_ps(&op[48], vop48);
1084         _mm256_storeu_ps(&op[56], vop56);
1085         _mm256_storeu_ps(&op[64], vop64);
1086         _mm256_storeu_ps(&op[72], vop72);
1087         _mm256_storeu_ps(&op[80], vop80);
1088         _mm256_storeu_ps(&op[88], vop88);
1089         _mm256_storeu_ps(&op[96], vop96);
1090         _mm256_storeu_ps(&op[104], vop104);
1091         _mm256_storeu_ps(&op[112], vop112);
1092         _mm256_storeu_ps(&op[120], vop120);
1093       } else {
1094         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1095         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1096         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1097         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1098         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1099         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1100         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1101         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1102         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1103         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1104         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1105         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1106         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1107         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1108         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1109         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1110         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1111       }
1112     }
1113   } else if (block_size == 64) {
1114     // unrolling 8 times
1115     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1116       float* op = &out[rangeIndex * block_size];
1117       __m256 vop0 = _mm256_setzero_ps();
1118       __m256 vop8 = _mm256_setzero_ps();
1119       __m256 vop16 = _mm256_setzero_ps();
1120       __m256 vop24 = _mm256_setzero_ps();
1121       __m256 vop32 = _mm256_setzero_ps();
1122       __m256 vop40 = _mm256_setzero_ps();
1123       __m256 vop48 = _mm256_setzero_ps();
1124       __m256 vop56 = _mm256_setzero_ps();
1125       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1126         return false;
1127       }
1128       int64_t end_offset = offsets[rangeIndex + 1];
1129       int64_t length = end_offset - offsets[rangeIndex];
1130       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1131            ++dataInd) {
1132         const int idx = indices[dataInd];
1133         if (idx < 0 || idx >= data_size) {
1134           return false;
1135         }
1136         float wgt = 1.f;
1137         if (weights) {
1138           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1139         }
1140         __m256 vwgt = _mm256_set1_ps(wgt);
1141         const at::Half* ip = &input[idx * fused_block_size];
1142         const int next_T0 = (dataInd < index_size - prefdist_T0)
1143             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1144             ? (dataInd + prefdist_T0)
1145             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1146             : dataInd;
1147         const int idx_pref_T0 = indices[next_T0];
1148         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1149           return false;
1150         }
1151         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1152         vop0 = _mm256_fmadd_ps(
1153             vwgt,
1154             _mm256_cvtph_ps(
1155                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1156             vop0);
1157         _mm_prefetch(
1158             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1159         vop8 = _mm256_fmadd_ps(
1160             vwgt,
1161             _mm256_cvtph_ps(
1162                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1163             vop8);
1164         // skip unnecessary prefetch of (&ip_next_T0[8])
1165         vop16 = _mm256_fmadd_ps(
1166             vwgt,
1167             _mm256_cvtph_ps(
1168                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1169             vop16);
1170         // skip unnecessary prefetch of (&ip_next_T0[16])
1171         vop24 = _mm256_fmadd_ps(
1172             vwgt,
1173             _mm256_cvtph_ps(
1174                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1175             vop24);
1176         // skip unnecessary prefetch of (&ip_next_T0[24])
1177         vop32 = _mm256_fmadd_ps(
1178             vwgt,
1179             _mm256_cvtph_ps(
1180                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1181             vop32);
1182         _mm_prefetch(
1183             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1184         vop40 = _mm256_fmadd_ps(
1185             vwgt,
1186             _mm256_cvtph_ps(
1187                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1188             vop40);
1189         // skip unnecessary prefetch of (&ip_next_T0[40])
1190         vop48 = _mm256_fmadd_ps(
1191             vwgt,
1192             _mm256_cvtph_ps(
1193                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1194             vop48);
1195         // skip unnecessary prefetch of (&ip_next_T0[48])
1196         vop56 = _mm256_fmadd_ps(
1197             vwgt,
1198             _mm256_cvtph_ps(
1199                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1200             vop56);
1201         // skip unnecessary prefetch of (&ip_next_T0[56])
1202       }
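      // Store the accumulators, scaled by 1 / length when
      // normalize_by_lengths is set and the segment is non-empty.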
1203       if (!normalize_by_lengths || length == 0) {
1204         _mm256_storeu_ps(&op[0], vop0);
1205         _mm256_storeu_ps(&op[8], vop8);
1206         _mm256_storeu_ps(&op[16], vop16);
1207         _mm256_storeu_ps(&op[24], vop24);
1208         _mm256_storeu_ps(&op[32], vop32);
1209         _mm256_storeu_ps(&op[40], vop40);
1210         _mm256_storeu_ps(&op[48], vop48);
1211         _mm256_storeu_ps(&op[56], vop56);
1212       } else {
1213         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1214         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1215         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1216         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1217         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1218         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1219         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1220         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1221         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1222       }
1223     }
1224   } else if (block_size == 32) {
1225     // unrolling 4 times
1226     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1227       float* op = &out[rangeIndex * block_size];
1228       __m256 vop0 = _mm256_setzero_ps();
1229       __m256 vop8 = _mm256_setzero_ps();
1230       __m256 vop16 = _mm256_setzero_ps();
1231       __m256 vop24 = _mm256_setzero_ps();
1232       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1233         return false;
1234       }
1235       int64_t end_offset = offsets[rangeIndex + 1];
1236       int64_t length = end_offset - offsets[rangeIndex];
1237       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1238            ++dataInd) {
1239         const int idx = indices[dataInd];
1240         if (idx < 0 || idx >= data_size) {
1241           return false;
1242         }
1243         float wgt = 1.f;
1244         if (weights) {
1245           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1246         }
1247         __m256 vwgt = _mm256_set1_ps(wgt);
1248         const at::Half* ip = &input[idx * fused_block_size];
1249         const int next_T0 = (dataInd < index_size - prefdist_T0)
1250             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1251             ? (dataInd + prefdist_T0)
1252             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1253             : dataInd;
1254         const int idx_pref_T0 = indices[next_T0];
1255         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1256           return false;
1257         }
1258         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1259         vop0 = _mm256_fmadd_ps(
1260             vwgt,
1261             _mm256_cvtph_ps(
1262                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1263             vop0);
1264         _mm_prefetch(
1265             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1266         vop8 = _mm256_fmadd_ps(
1267             vwgt,
1268             _mm256_cvtph_ps(
1269                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1270             vop8);
1271         // skip unnecessary prefetch of (&ip_next_T0[8])
1272         vop16 = _mm256_fmadd_ps(
1273             vwgt,
1274             _mm256_cvtph_ps(
1275                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1276             vop16);
1277         // skip unnecessary prefetch of (&ip_next_T0[16])
1278         vop24 = _mm256_fmadd_ps(
1279             vwgt,
1280             _mm256_cvtph_ps(
1281                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1282             vop24);
1283         // skip unnecessary prefetch of (&ip_next_T0[24])
1284       }
1285       if (!normalize_by_lengths || length == 0) {
1286         _mm256_storeu_ps(&op[0], vop0);
1287         _mm256_storeu_ps(&op[8], vop8);
1288         _mm256_storeu_ps(&op[16], vop16);
1289         _mm256_storeu_ps(&op[24], vop24);
1290       } else {
1291         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1292         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1293         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1294         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1295         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1296       }
1297     }
1298   } else if (block_size == 16) {
1299     // unrolling 2 times
1300     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1301       float* op = &out[rangeIndex * block_size];
1302       __m256 vop0 = _mm256_setzero_ps();
1303       __m256 vop8 = _mm256_setzero_ps();
1304       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1305         return false;
1306       }
1307       int64_t end_offset = offsets[rangeIndex + 1];
1308       int64_t length = end_offset - offsets[rangeIndex];
1309       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1310            ++dataInd) {
1311         const int idx = indices[dataInd];
1312         if (idx < 0 || idx >= data_size) {
1313           return false;
1314         }
1315         float wgt = 1.f;
1316         if (weights) {
1317           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1318         }
1319         __m256 vwgt = _mm256_set1_ps(wgt);
1320         const at::Half* ip = &input[idx * fused_block_size];
1321         const int next_T0 = (dataInd < index_size - prefdist_T0)
1322             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1323             ? (dataInd + prefdist_T0)
1324             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1325             : dataInd;
1326         const int idx_pref_T0 = indices[next_T0];
1327         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1328           return false;
1329         }
1330         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1331         vop0 = _mm256_fmadd_ps(
1332             vwgt,
1333             _mm256_cvtph_ps(
1334                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1335             vop0);
1336         _mm_prefetch(
1337             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1338         vop8 = _mm256_fmadd_ps(
1339             vwgt,
1340             _mm256_cvtph_ps(
1341                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1342             vop8);
1343         // skip unnecessary prefetch of (&ip_next_T0[8])
1344       }
1345       if (!normalize_by_lengths || length == 0) {
1346         _mm256_storeu_ps(&op[0], vop0);
1347         _mm256_storeu_ps(&op[8], vop8);
1348       } else {
1349         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1350         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1351         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1352       }
1353     }
1354   } else {
1355     // generic code
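    // Arbitrary block_size: accumulate directly into out[] in chunks of 8,
    // with a scalar tail that widens single fp16 values through vtmp1.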
1356     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
1357     alignas(64) at::Half vtmp1[8] = {0};
1358     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1359       float* op = &out[rangeIndex * block_size];
1360       int64_t j = 0;
1361       for (; j + 8 <= block_size; j += 8) {
1362         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1363       }
1364       for (; j < block_size; j++) {
1365         op[j] = 0.0f;
1366       }
1367       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1368         return false;
1369       }
1370       int64_t end_offset = offsets[rangeIndex + 1];
1371       int64_t length = end_offset - offsets[rangeIndex];
1372       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1373            ++dataInd) {
1374         const int idx = indices[dataInd];
1375         if (idx < 0 || idx >= data_size) {
1376           return false;
1377         }
1378         float wgt = 1.f;
1379         if (weights) {
1380           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1381         }
1382         __m256 vwgt = _mm256_set1_ps(wgt);
1383         const at::Half* ip = &input[idx * fused_block_size];
1384         const int next_T0 = (dataInd < index_size - prefdist_T0)
1385             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1386             ? (dataInd + prefdist_T0)
1387             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1388             : dataInd;
1389         const int idx_pref_T0 = indices[next_T0];
1390         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1391           return false;
1392         }
1393         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1394         j = 0;
1395         for (; j + 8 <= block_size; j += 8) {
1396           _mm256_storeu_ps(
1397               &op[j],
1398               _mm256_fmadd_ps(
1399                   vwgt,
1400                   _mm256_cvtph_ps(_mm_loadu_si128(
1401                       reinterpret_cast<const __m128i*>(&ip[j]))),
1402                   _mm256_loadu_ps(&op[j])));
1403           _mm_prefetch(
1404               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1405         }
1406         for (; j < block_size; j++) {
1407           vtmp1[0] = ip[j];
1408           __m256 vtmp2 =
1409               _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
1410           op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
1411         }
1412       }
1413       if (normalize_by_lengths && length) {
1414         float len_inv = 1.0f / length;
1415         __m256 vlen_inv = _mm256_set1_ps(len_inv);
1416         j = 0;
1417         for (; j + 8 <= block_size; j += 8) {
1418           _mm256_storeu_ps(
1419               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1420         }
1421         for (; j < block_size; j++) {
1422           op[j] = len_inv * op[j];
1423         }
1424       }
1425     }
1426   }
1427   return dataInd == index_size;
1428 }
1429 bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
1430     const int64_t block_size,
1431     const int64_t output_size,
1432     const int64_t index_size,
1433     const int64_t data_size,
1434     const at::Half* input,
1435     const int* indices,
1436     const int* offsets,
1437     const float* weights,
1438     const float* scale_bias,
1439     bool normalize_by_lengths,
1440     float* out) {
1441   return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<false>(
1442       block_size,
1443       output_size,
1444       index_size,
1445       data_size,
1446       input,
1447       indices,
1448       offsets,
1449       weights,
1450       scale_bias,
1451       normalize_by_lengths,
1452       out);
1453 }
1454 bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
1455     const int64_t block_size,
1456     const int64_t output_size,
1457     const int64_t index_size,
1458     const int64_t data_size,
1459     const at::Half* input,
1460     const int* indices,
1461     const int* offsets,
1462     const float* weights,
1463     const float* scale_bias,
1464     bool normalize_by_lengths,
1465     float* out) {
1466   return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<true>(
1467       block_size,
1468       output_size,
1469       index_size,
1470       data_size,
1471       input,
1472       indices,
1473       offsets,
1474       weights,
1475       scale_bias,
1476       normalize_by_lengths,
1477       out);
1478 }
1479 
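// Same fp16 kernel as above, specialized for 64-bit indices and offsets.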
1480 template <bool IS_WEIGHT_POSITIONAL>
1481 static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma(
1482     const int64_t block_size,
1483     const int64_t output_size,
1484     const int64_t index_size,
1485     const int64_t data_size,
1486     const at::Half* input,
1487     const int64_t* indices,
1488     const int64_t* offsets,
1489     const float* weights,
1490     const float* scale_bias,
1491     bool normalize_by_lengths,
1492     float* out) {
1493   const int64_t prefdist_T0 = 16;
1494   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1495   const int64_t fused_block_size = block_size + 0;
1496   int64_t dataInd = 0;
1497   if (block_size == 128) {
1498     // unrolling 16 times
1499     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1500       float* op = &out[rangeIndex * block_size];
1501       __m256 vop0 = _mm256_setzero_ps();
1502       __m256 vop8 = _mm256_setzero_ps();
1503       __m256 vop16 = _mm256_setzero_ps();
1504       __m256 vop24 = _mm256_setzero_ps();
1505       __m256 vop32 = _mm256_setzero_ps();
1506       __m256 vop40 = _mm256_setzero_ps();
1507       __m256 vop48 = _mm256_setzero_ps();
1508       __m256 vop56 = _mm256_setzero_ps();
1509       __m256 vop64 = _mm256_setzero_ps();
1510       __m256 vop72 = _mm256_setzero_ps();
1511       __m256 vop80 = _mm256_setzero_ps();
1512       __m256 vop88 = _mm256_setzero_ps();
1513       __m256 vop96 = _mm256_setzero_ps();
1514       __m256 vop104 = _mm256_setzero_ps();
1515       __m256 vop112 = _mm256_setzero_ps();
1516       __m256 vop120 = _mm256_setzero_ps();
1517       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1518         return false;
1519       }
1520       int64_t end_offset = offsets[rangeIndex + 1];
1521       int64_t length = end_offset - offsets[rangeIndex];
1522       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1523            ++dataInd) {
1524         const int64_t idx = indices[dataInd];
1525         if (idx < 0 || idx >= data_size) {
1526           return false;
1527         }
1528         float wgt = 1.f;
1529         if (weights) {
1530           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1531         }
1532         __m256 vwgt = _mm256_set1_ps(wgt);
1533         const at::Half* ip = &input[idx * fused_block_size];
1534         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1535             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1536             ? (dataInd + prefdist_T0)
1537             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1538             : dataInd;
1539         const int64_t idx_pref_T0 = indices[next_T0];
1540         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1541           return false;
1542         }
1543         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1544         vop0 = _mm256_fmadd_ps(
1545             vwgt,
1546             _mm256_cvtph_ps(
1547                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1548             vop0);
1549         _mm_prefetch(
1550             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1551         vop8 = _mm256_fmadd_ps(
1552             vwgt,
1553             _mm256_cvtph_ps(
1554                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1555             vop8);
1556         // skip unnecessary prefetch of (&ip_next_T0[8])
1557         vop16 = _mm256_fmadd_ps(
1558             vwgt,
1559             _mm256_cvtph_ps(
1560                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1561             vop16);
1562         // skip unnecessary prefetch of (&ip_next_T0[16])
1563         vop24 = _mm256_fmadd_ps(
1564             vwgt,
1565             _mm256_cvtph_ps(
1566                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1567             vop24);
1568         // skip unnecessary prefetch of (&ip_next_T0[24])
1569         vop32 = _mm256_fmadd_ps(
1570             vwgt,
1571             _mm256_cvtph_ps(
1572                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1573             vop32);
1574         _mm_prefetch(
1575             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1576         vop40 = _mm256_fmadd_ps(
1577             vwgt,
1578             _mm256_cvtph_ps(
1579                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1580             vop40);
1581         // skip unnecessary prefetch of (&ip_next_T0[40])
1582         vop48 = _mm256_fmadd_ps(
1583             vwgt,
1584             _mm256_cvtph_ps(
1585                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1586             vop48);
1587         // skip unnecessary prefetch of (&ip_next_T0[48])
1588         vop56 = _mm256_fmadd_ps(
1589             vwgt,
1590             _mm256_cvtph_ps(
1591                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1592             vop56);
1593         // skip unnecessary prefetch of (&ip_next_T0[56])
1594         vop64 = _mm256_fmadd_ps(
1595             vwgt,
1596             _mm256_cvtph_ps(
1597                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1598             vop64);
1599         _mm_prefetch(
1600             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1601         vop72 = _mm256_fmadd_ps(
1602             vwgt,
1603             _mm256_cvtph_ps(
1604                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1605             vop72);
1606         // skip unnecessary prefetch of (&ip_next_T0[72])
1607         vop80 = _mm256_fmadd_ps(
1608             vwgt,
1609             _mm256_cvtph_ps(
1610                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1611             vop80);
1612         // skip unnecessary prefetch of (&ip_next_T0[80])
1613         vop88 = _mm256_fmadd_ps(
1614             vwgt,
1615             _mm256_cvtph_ps(
1616                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1617             vop88);
1618         // skip unnecessary prefetch of (&ip_next_T0[88])
1619         vop96 = _mm256_fmadd_ps(
1620             vwgt,
1621             _mm256_cvtph_ps(
1622                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1623             vop96);
1624         _mm_prefetch(
1625             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1626         vop104 = _mm256_fmadd_ps(
1627             vwgt,
1628             _mm256_cvtph_ps(
1629                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1630             vop104);
1631         // skip unnecessary prefetch of (&ip_next_T0[104])
1632         vop112 = _mm256_fmadd_ps(
1633             vwgt,
1634             _mm256_cvtph_ps(
1635                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1636             vop112);
1637         // skip unnecessary prefetch of (&ip_next_T0[112])
1638         vop120 = _mm256_fmadd_ps(
1639             vwgt,
1640             _mm256_cvtph_ps(
1641                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1642             vop120);
1643         // skip unnecessary prefetch of (&ip_next_T0[120])
1644       }
1645       if (!normalize_by_lengths || length == 0) {
1646         _mm256_storeu_ps(&op[0], vop0);
1647         _mm256_storeu_ps(&op[8], vop8);
1648         _mm256_storeu_ps(&op[16], vop16);
1649         _mm256_storeu_ps(&op[24], vop24);
1650         _mm256_storeu_ps(&op[32], vop32);
1651         _mm256_storeu_ps(&op[40], vop40);
1652         _mm256_storeu_ps(&op[48], vop48);
1653         _mm256_storeu_ps(&op[56], vop56);
1654         _mm256_storeu_ps(&op[64], vop64);
1655         _mm256_storeu_ps(&op[72], vop72);
1656         _mm256_storeu_ps(&op[80], vop80);
1657         _mm256_storeu_ps(&op[88], vop88);
1658         _mm256_storeu_ps(&op[96], vop96);
1659         _mm256_storeu_ps(&op[104], vop104);
1660         _mm256_storeu_ps(&op[112], vop112);
1661         _mm256_storeu_ps(&op[120], vop120);
1662       } else {
1663         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1664         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1665         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1666         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1667         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1668         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1669         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1670         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1671         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1672         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1673         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1674         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1675         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1676         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1677         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1678         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1679         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1680       }
1681     }
1682   } else if (block_size == 64) {
1683     // unrolling 8 times
1684     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1685       float* op = &out[rangeIndex * block_size];
1686       __m256 vop0 = _mm256_setzero_ps();
1687       __m256 vop8 = _mm256_setzero_ps();
1688       __m256 vop16 = _mm256_setzero_ps();
1689       __m256 vop24 = _mm256_setzero_ps();
1690       __m256 vop32 = _mm256_setzero_ps();
1691       __m256 vop40 = _mm256_setzero_ps();
1692       __m256 vop48 = _mm256_setzero_ps();
1693       __m256 vop56 = _mm256_setzero_ps();
1694       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1695         return false;
1696       }
1697       int64_t end_offset = offsets[rangeIndex + 1];
1698       int64_t length = end_offset - offsets[rangeIndex];
1699       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1700            ++dataInd) {
1701         const int64_t idx = indices[dataInd];
1702         if (idx < 0 || idx >= data_size) {
1703           return false;
1704         }
1705         float wgt = 1.f;
1706         if (weights) {
1707           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1708         }
1709         __m256 vwgt = _mm256_set1_ps(wgt);
1710         const at::Half* ip = &input[idx * fused_block_size];
1711         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1712             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1713             ? (dataInd + prefdist_T0)
1714             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1715             : dataInd;
1716         const int64_t idx_pref_T0 = indices[next_T0];
1717         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1718           return false;
1719         }
1720         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1721         vop0 = _mm256_fmadd_ps(
1722             vwgt,
1723             _mm256_cvtph_ps(
1724                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1725             vop0);
1726         _mm_prefetch(
1727             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1728         vop8 = _mm256_fmadd_ps(
1729             vwgt,
1730             _mm256_cvtph_ps(
1731                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1732             vop8);
1733         // skip unnecessary prefetch of (&ip_next_T0[8])
1734         vop16 = _mm256_fmadd_ps(
1735             vwgt,
1736             _mm256_cvtph_ps(
1737                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1738             vop16);
1739         // skip unnecessary prefetch of (&ip_next_T0[16])
1740         vop24 = _mm256_fmadd_ps(
1741             vwgt,
1742             _mm256_cvtph_ps(
1743                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1744             vop24);
1745         // skip unnecessary prefetch of (&ip_next_T0[24])
1746         vop32 = _mm256_fmadd_ps(
1747             vwgt,
1748             _mm256_cvtph_ps(
1749                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1750             vop32);
1751         _mm_prefetch(
1752             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1753         vop40 = _mm256_fmadd_ps(
1754             vwgt,
1755             _mm256_cvtph_ps(
1756                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1757             vop40);
1758         // skip unnecessary prefetch of (&ip_next_T0[40])
1759         vop48 = _mm256_fmadd_ps(
1760             vwgt,
1761             _mm256_cvtph_ps(
1762                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1763             vop48);
1764         // skip unnecessary prefetch of (&ip_next_T0[48])
1765         vop56 = _mm256_fmadd_ps(
1766             vwgt,
1767             _mm256_cvtph_ps(
1768                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1769             vop56);
1770         // skip unnecessary prefetch of (&ip_next_T0[56])
1771       }
1772       if (!normalize_by_lengths || length == 0) {
1773         _mm256_storeu_ps(&op[0], vop0);
1774         _mm256_storeu_ps(&op[8], vop8);
1775         _mm256_storeu_ps(&op[16], vop16);
1776         _mm256_storeu_ps(&op[24], vop24);
1777         _mm256_storeu_ps(&op[32], vop32);
1778         _mm256_storeu_ps(&op[40], vop40);
1779         _mm256_storeu_ps(&op[48], vop48);
1780         _mm256_storeu_ps(&op[56], vop56);
1781       } else {
1782         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1783         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1784         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1785         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1786         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1787         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1788         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1789         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1790         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1791       }
1792     }
1793   } else if (block_size == 32) {
1794     // unrolling 4 times
1795     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1796       float* op = &out[rangeIndex * block_size];
1797       __m256 vop0 = _mm256_setzero_ps();
1798       __m256 vop8 = _mm256_setzero_ps();
1799       __m256 vop16 = _mm256_setzero_ps();
1800       __m256 vop24 = _mm256_setzero_ps();
1801       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1802         return false;
1803       }
1804       int64_t end_offset = offsets[rangeIndex + 1];
1805       int64_t length = end_offset - offsets[rangeIndex];
1806       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1807            ++dataInd) {
1808         const int64_t idx = indices[dataInd];
1809         if (idx < 0 || idx >= data_size) {
1810           return false;
1811         }
1812         float wgt = 1.f;
1813         if (weights) {
1814           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1815         }
1816         __m256 vwgt = _mm256_set1_ps(wgt);
1817         const at::Half* ip = &input[idx * fused_block_size];
1818         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1819             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1820             ? (dataInd + prefdist_T0)
1821             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1822             : dataInd;
1823         const int64_t idx_pref_T0 = indices[next_T0];
1824         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1825           return false;
1826         }
1827         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1828         vop0 = _mm256_fmadd_ps(
1829             vwgt,
1830             _mm256_cvtph_ps(
1831                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1832             vop0);
1833         _mm_prefetch(
1834             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1835         vop8 = _mm256_fmadd_ps(
1836             vwgt,
1837             _mm256_cvtph_ps(
1838                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1839             vop8);
1840         // skip unnecessary prefetch of (&ip_next_T0[8])
1841         vop16 = _mm256_fmadd_ps(
1842             vwgt,
1843             _mm256_cvtph_ps(
1844                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1845             vop16);
1846         // skip unnecessary prefetch of (&ip_next_T0[16])
1847         vop24 = _mm256_fmadd_ps(
1848             vwgt,
1849             _mm256_cvtph_ps(
1850                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1851             vop24);
1852         // skip unnecessary prefetch of (&ip_next_T0[24])
1853       }
1854       if (!normalize_by_lengths || length == 0) {
1855         _mm256_storeu_ps(&op[0], vop0);
1856         _mm256_storeu_ps(&op[8], vop8);
1857         _mm256_storeu_ps(&op[16], vop16);
1858         _mm256_storeu_ps(&op[24], vop24);
1859       } else {
1860         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1861         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1862         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1863         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1864         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1865       }
1866     }
1867   } else if (block_size == 16) {
1868     // unrolling 2 times
1869     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1870       float* op = &out[rangeIndex * block_size];
1871       __m256 vop0 = _mm256_setzero_ps();
1872       __m256 vop8 = _mm256_setzero_ps();
1873       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1874         return false;
1875       }
1876       int64_t end_offset = offsets[rangeIndex + 1];
1877       int64_t length = end_offset - offsets[rangeIndex];
1878       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1879            ++dataInd) {
1880         const int64_t idx = indices[dataInd];
1881         if (idx < 0 || idx >= data_size) {
1882           return false;
1883         }
1884         float wgt = 1.f;
1885         if (weights) {
1886           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1887         }
1888         __m256 vwgt = _mm256_set1_ps(wgt);
1889         const at::Half* ip = &input[idx * fused_block_size];
1890         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1891             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1892             ? (dataInd + prefdist_T0)
1893             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1894             : dataInd;
1895         const int64_t idx_pref_T0 = indices[next_T0];
1896         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1897           return false;
1898         }
1899         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1900         vop0 = _mm256_fmadd_ps(
1901             vwgt,
1902             _mm256_cvtph_ps(
1903                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1904             vop0);
1905         _mm_prefetch(
1906             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1907         vop8 = _mm256_fmadd_ps(
1908             vwgt,
1909             _mm256_cvtph_ps(
1910                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1911             vop8);
1912         // skip unnecessary prefetch of (&ip_next_T0[8])
1913       }
1914       if (!normalize_by_lengths || length == 0) {
1915         _mm256_storeu_ps(&op[0], vop0);
1916         _mm256_storeu_ps(&op[8], vop8);
1917       } else {
1918         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1919         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1920         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1921       }
1922     }
1923   } else {
1924     // generic code
1925     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
1926     alignas(64) at::Half vtmp1[8] = {0};
1927     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1928       float* op = &out[rangeIndex * block_size];
1929       int64_t j = 0;
1930       for (; j + 8 <= block_size; j += 8) {
1931         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1932       }
1933       for (; j < block_size; j++) {
1934         op[j] = 0.0f;
1935       }
1936       if (dataInd != offsets[rangeIndex] - offsets[0]) {
1937         return false;
1938       }
1939       int64_t end_offset = offsets[rangeIndex + 1];
1940       int64_t length = end_offset - offsets[rangeIndex];
1941       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1942            ++dataInd) {
1943         const int64_t idx = indices[dataInd];
1944         if (idx < 0 || idx >= data_size) {
1945           return false;
1946         }
1947         float wgt = 1.f;
1948         if (weights) {
1949           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1950         }
1951         __m256 vwgt = _mm256_set1_ps(wgt);
1952         const at::Half* ip = &input[idx * fused_block_size];
1953         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1954             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1955             ? (dataInd + prefdist_T0)
1956             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1957             : dataInd;
1958         const int64_t idx_pref_T0 = indices[next_T0];
1959         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1960           return false;
1961         }
1962         const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1963         j = 0;
1964         for (; j + 8 <= block_size; j += 8) {
1965           _mm256_storeu_ps(
1966               &op[j],
1967               _mm256_fmadd_ps(
1968                   vwgt,
1969                   _mm256_cvtph_ps(_mm_loadu_si128(
1970                       reinterpret_cast<const __m128i*>(&ip[j]))),
1971                   _mm256_loadu_ps(&op[j])));
1972           _mm_prefetch(
1973               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1974         }
1975         for (; j < block_size; j++) {
1976           vtmp1[0] = ip[j];
1977           __m256 vtmp2 =
1978               _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
1979           op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
1980         }
1981       }
1982       if (normalize_by_lengths && length) {
1983         float len_inv = 1.0f / length;
1984         __m256 vlen_inv = _mm256_set1_ps(len_inv);
1985         j = 0;
1986         for (; j + 8 <= block_size; j += 8) {
1987           _mm256_storeu_ps(
1988               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1989         }
1990         for (; j < block_size; j++) {
1991           op[j] = len_inv * op[j];
1992         }
1993       }
1994     }
1995   }
1996   return dataInd == index_size;
1997 }
1998 bool EmbeddingLookupIdx_int64_t_half_float_false__avx2_fma(
1999     const int64_t block_size,
2000     const int64_t output_size,
2001     const int64_t index_size,
2002     const int64_t data_size,
2003     const at::Half* input,
2004     const int64_t* indices,
2005     const int64_t* offsets,
2006     const float* weights,
2007     const float* scale_bias,
2008     bool normalize_by_lengths,
2009     float* out) {
2010   return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<false>(
2011       block_size,
2012       output_size,
2013       index_size,
2014       data_size,
2015       input,
2016       indices,
2017       offsets,
2018       weights,
2019       scale_bias,
2020       normalize_by_lengths,
2021       out);
2022 }
2023 bool EmbeddingLookupIdx_int64_t_half_float_true__avx2_fma(
2024     const int64_t block_size,
2025     const int64_t output_size,
2026     const int64_t index_size,
2027     const int64_t data_size,
2028     const at::Half* input,
2029     const int64_t* indices,
2030     const int64_t* offsets,
2031     const float* weights,
2032     const float* scale_bias,
2033     bool normalize_by_lengths,
2034     float* out) {
2035   return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<true>(
2036       block_size,
2037       output_size,
2038       index_size,
2039       data_size,
2040       input,
2041       indices,
2042       offsets,
2043       weights,
2044       scale_bias,
2045       normalize_by_lengths,
2046       out);
2047 }
2048 
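// bfloat16 variants: identical structure, but values are widened to fp32 by
// zero-extending to 32 bits and shifting left by 16 instead of using F16C.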
2049 template <bool IS_WEIGHT_POSITIONAL>
2050 static bool EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma(
2051     const int64_t block_size,
2052     const int64_t output_size,
2053     const int64_t index_size,
2054     const int64_t data_size,
2055     const at::BFloat16* input,
2056     const int* indices,
2057     const int* offsets,
2058     const float* weights,
2059     const float* scale_bias,
2060     bool normalize_by_lengths,
2061     float* out) {
2062   const int prefdist_T0 = 16;
2063   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2064   const int fused_block_size = block_size + 0;
2065   int64_t dataInd = 0;
2066   if (block_size == 128) {
2067     // unrolling 16 times
2068     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2069       float* op = &out[rangeIndex * block_size];
2070       __m256 vop0 = _mm256_setzero_ps();
2071       __m256 vop8 = _mm256_setzero_ps();
2072       __m256 vop16 = _mm256_setzero_ps();
2073       __m256 vop24 = _mm256_setzero_ps();
2074       __m256 vop32 = _mm256_setzero_ps();
2075       __m256 vop40 = _mm256_setzero_ps();
2076       __m256 vop48 = _mm256_setzero_ps();
2077       __m256 vop56 = _mm256_setzero_ps();
2078       __m256 vop64 = _mm256_setzero_ps();
2079       __m256 vop72 = _mm256_setzero_ps();
2080       __m256 vop80 = _mm256_setzero_ps();
2081       __m256 vop88 = _mm256_setzero_ps();
2082       __m256 vop96 = _mm256_setzero_ps();
2083       __m256 vop104 = _mm256_setzero_ps();
2084       __m256 vop112 = _mm256_setzero_ps();
2085       __m256 vop120 = _mm256_setzero_ps();
2086       if (dataInd != offsets[rangeIndex] - offsets[0]) {
2087         return false;
2088       }
2089       int64_t end_offset = offsets[rangeIndex + 1];
2090       int64_t length = end_offset - offsets[rangeIndex];
2091       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2092            ++dataInd) {
2093         const int idx = indices[dataInd];
2094         if (idx < 0 || idx >= data_size) {
2095           return false;
2096         }
2097         float wgt = 1.f;
2098         if (weights) {
2099           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2100         }
2101         __m256 vwgt = _mm256_set1_ps(wgt);
2102         const at::BFloat16* ip = &input[idx * fused_block_size];
2103         const int next_T0 = (dataInd < index_size - prefdist_T0)
2104             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2105             ? (dataInd + prefdist_T0)
2106             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2107             : dataInd;
2108         const int idx_pref_T0 = indices[next_T0];
2109         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2110           return false;
2111         }
2112         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
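        // bf16 -> fp32: zero-extend each 16-bit value into a 32-bit lane,
        // then shift left by 16 so the bits occupy the float's upper half.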
2113         vop0 = _mm256_fmadd_ps(
2114             vwgt,
2115             _mm256_castsi256_ps(_mm256_slli_epi32(
2116                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2117                     reinterpret_cast<const __m128i*>(ip + (0)))),
2118                 16)),
2119             vop0);
2120         _mm_prefetch(
2121             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2122         vop8 = _mm256_fmadd_ps(
2123             vwgt,
2124             _mm256_castsi256_ps(_mm256_slli_epi32(
2125                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2126                     reinterpret_cast<const __m128i*>(ip + (8)))),
2127                 16)),
2128             vop8);
2129         // skip unnecessary prefetch of (&ip_next_T0[8])
2130         vop16 = _mm256_fmadd_ps(
2131             vwgt,
2132             _mm256_castsi256_ps(_mm256_slli_epi32(
2133                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2134                     reinterpret_cast<const __m128i*>(ip + (16)))),
2135                 16)),
2136             vop16);
2137         // skip unnecessary prefetch of (&ip_next_T0[16])
2138         vop24 = _mm256_fmadd_ps(
2139             vwgt,
2140             _mm256_castsi256_ps(_mm256_slli_epi32(
2141                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2142                     reinterpret_cast<const __m128i*>(ip + (24)))),
2143                 16)),
2144             vop24);
2145         // skip unnecessary prefetch of (&ip_next_T0[24])
2146         vop32 = _mm256_fmadd_ps(
2147             vwgt,
2148             _mm256_castsi256_ps(_mm256_slli_epi32(
2149                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2150                     reinterpret_cast<const __m128i*>(ip + (32)))),
2151                 16)),
2152             vop32);
2153         _mm_prefetch(
2154             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2155         vop40 = _mm256_fmadd_ps(
2156             vwgt,
2157             _mm256_castsi256_ps(_mm256_slli_epi32(
2158                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2159                     reinterpret_cast<const __m128i*>(ip + (40)))),
2160                 16)),
2161             vop40);
2162         // skip unnecessary prefetch of (&ip_next_T0[40])
2163         vop48 = _mm256_fmadd_ps(
2164             vwgt,
2165             _mm256_castsi256_ps(_mm256_slli_epi32(
2166                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2167                     reinterpret_cast<const __m128i*>(ip + (48)))),
2168                 16)),
2169             vop48);
2170         // skip unnecessary prefetch of (&ip_next_T0[48])
2171         vop56 = _mm256_fmadd_ps(
2172             vwgt,
2173             _mm256_castsi256_ps(_mm256_slli_epi32(
2174                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2175                     reinterpret_cast<const __m128i*>(ip + (56)))),
2176                 16)),
2177             vop56);
2178         // skip unnecessary prefetch of (&ip_next_T0[56])
2179         vop64 = _mm256_fmadd_ps(
2180             vwgt,
2181             _mm256_castsi256_ps(_mm256_slli_epi32(
2182                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2183                     reinterpret_cast<const __m128i*>(ip + (64)))),
2184                 16)),
2185             vop64);
2186         _mm_prefetch(
2187             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2188         vop72 = _mm256_fmadd_ps(
2189             vwgt,
2190             _mm256_castsi256_ps(_mm256_slli_epi32(
2191                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2192                     reinterpret_cast<const __m128i*>(ip + (72)))),
2193                 16)),
2194             vop72);
2195         // skip unnecessary prefetch of (&ip_next_T0[72])
2196         vop80 = _mm256_fmadd_ps(
2197             vwgt,
2198             _mm256_castsi256_ps(_mm256_slli_epi32(
2199                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2200                     reinterpret_cast<const __m128i*>(ip + (80)))),
2201                 16)),
2202             vop80);
2203         // skip unnecessary prefetch of (&ip_next_T0[80])
2204         vop88 = _mm256_fmadd_ps(
2205             vwgt,
2206             _mm256_castsi256_ps(_mm256_slli_epi32(
2207                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2208                     reinterpret_cast<const __m128i*>(ip + (88)))),
2209                 16)),
2210             vop88);
2211         // skip unnecessary prefetch of (&ip_next_T0[88])
2212         vop96 = _mm256_fmadd_ps(
2213             vwgt,
2214             _mm256_castsi256_ps(_mm256_slli_epi32(
2215                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2216                     reinterpret_cast<const __m128i*>(ip + (96)))),
2217                 16)),
2218             vop96);
2219         _mm_prefetch(
2220             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
2221         vop104 = _mm256_fmadd_ps(
2222             vwgt,
2223             _mm256_castsi256_ps(_mm256_slli_epi32(
2224                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2225                     reinterpret_cast<const __m128i*>(ip + (104)))),
2226                 16)),
2227             vop104);
2228         // skip unnecessary prefetch of (&ip_next_T0[104])
2229         vop112 = _mm256_fmadd_ps(
2230             vwgt,
2231             _mm256_castsi256_ps(_mm256_slli_epi32(
2232                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2233                     reinterpret_cast<const __m128i*>(ip + (112)))),
2234                 16)),
2235             vop112);
2236         // skip unnecessary prefetch of (&ip_next_T0[112])
2237         vop120 = _mm256_fmadd_ps(
2238             vwgt,
2239             _mm256_castsi256_ps(_mm256_slli_epi32(
2240                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2241                     reinterpret_cast<const __m128i*>(ip + (120)))),
2242                 16)),
2243             vop120);
2244         // skip unnecessary prefetch of (&ip_next_T0[120])
2245       }
2246       if (!normalize_by_lengths || length == 0) {
2247         _mm256_storeu_ps(&op[0], vop0);
2248         _mm256_storeu_ps(&op[8], vop8);
2249         _mm256_storeu_ps(&op[16], vop16);
2250         _mm256_storeu_ps(&op[24], vop24);
2251         _mm256_storeu_ps(&op[32], vop32);
2252         _mm256_storeu_ps(&op[40], vop40);
2253         _mm256_storeu_ps(&op[48], vop48);
2254         _mm256_storeu_ps(&op[56], vop56);
2255         _mm256_storeu_ps(&op[64], vop64);
2256         _mm256_storeu_ps(&op[72], vop72);
2257         _mm256_storeu_ps(&op[80], vop80);
2258         _mm256_storeu_ps(&op[88], vop88);
2259         _mm256_storeu_ps(&op[96], vop96);
2260         _mm256_storeu_ps(&op[104], vop104);
2261         _mm256_storeu_ps(&op[112], vop112);
2262         _mm256_storeu_ps(&op[120], vop120);
2263       } else {
2264         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2265         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2266         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2267         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2268         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2269         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2270         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2271         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2272         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2273         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2274         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2275         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2276         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2277         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2278         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2279         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2280         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2281       }
2282     }
2283   } else if (block_size == 64) {
2284     // unrolling 8 times
2285     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2286       float* op = &out[rangeIndex * block_size];
2287       __m256 vop0 = _mm256_setzero_ps();
2288       __m256 vop8 = _mm256_setzero_ps();
2289       __m256 vop16 = _mm256_setzero_ps();
2290       __m256 vop24 = _mm256_setzero_ps();
2291       __m256 vop32 = _mm256_setzero_ps();
2292       __m256 vop40 = _mm256_setzero_ps();
2293       __m256 vop48 = _mm256_setzero_ps();
2294       __m256 vop56 = _mm256_setzero_ps();
2295       if (dataInd != offsets[rangeIndex] - offsets[0]) {
2296         return false;
2297       }
2298       int64_t end_offset = offsets[rangeIndex + 1];
2299       int64_t length = end_offset - offsets[rangeIndex];
2300       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2301            ++dataInd) {
2302         const int idx = indices[dataInd];
2303         if (idx < 0 || idx >= data_size) {
2304           return false;
2305         }
2306         float wgt = 1.f;
2307         if (weights) {
2308           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2309         }
2310         __m256 vwgt = _mm256_set1_ps(wgt);
2311         const at::BFloat16* ip = &input[idx * fused_block_size];
2312         const int next_T0 = (dataInd < index_size - prefdist_T0)
2313             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2314             ? (dataInd + prefdist_T0)
2315             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2316             : dataInd;
2317         const int idx_pref_T0 = indices[next_T0];
2318         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2319           return false;
2320         }
2321         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2322         vop0 = _mm256_fmadd_ps(
2323             vwgt,
2324             _mm256_castsi256_ps(_mm256_slli_epi32(
2325                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2326                     reinterpret_cast<const __m128i*>(ip + (0)))),
2327                 16)),
2328             vop0);
2329         _mm_prefetch(
2330             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2331         vop8 = _mm256_fmadd_ps(
2332             vwgt,
2333             _mm256_castsi256_ps(_mm256_slli_epi32(
2334                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2335                     reinterpret_cast<const __m128i*>(ip + (8)))),
2336                 16)),
2337             vop8);
2338         // skip unnecessary prefetch of (&ip_next_T0[8])
2339         vop16 = _mm256_fmadd_ps(
2340             vwgt,
2341             _mm256_castsi256_ps(_mm256_slli_epi32(
2342                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2343                     reinterpret_cast<const __m128i*>(ip + (16)))),
2344                 16)),
2345             vop16);
2346         // skip unnecessary prefetch of (&ip_next_T0[16])
2347         vop24 = _mm256_fmadd_ps(
2348             vwgt,
2349             _mm256_castsi256_ps(_mm256_slli_epi32(
2350                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2351                     reinterpret_cast<const __m128i*>(ip + (24)))),
2352                 16)),
2353             vop24);
2354         // skip unnecessary prefetch of (&ip_next_T0[24])
2355         vop32 = _mm256_fmadd_ps(
2356             vwgt,
2357             _mm256_castsi256_ps(_mm256_slli_epi32(
2358                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2359                     reinterpret_cast<const __m128i*>(ip + (32)))),
2360                 16)),
2361             vop32);
2362         _mm_prefetch(
2363             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2364         vop40 = _mm256_fmadd_ps(
2365             vwgt,
2366             _mm256_castsi256_ps(_mm256_slli_epi32(
2367                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2368                     reinterpret_cast<const __m128i*>(ip + (40)))),
2369                 16)),
2370             vop40);
2371         // skip unnecessary prefetch of (&ip_next_T0[40])
2372         vop48 = _mm256_fmadd_ps(
2373             vwgt,
2374             _mm256_castsi256_ps(_mm256_slli_epi32(
2375                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2376                     reinterpret_cast<const __m128i*>(ip + (48)))),
2377                 16)),
2378             vop48);
2379         // skip unnecessary prefetch of (&ip_next_T0[48])
2380         vop56 = _mm256_fmadd_ps(
2381             vwgt,
2382             _mm256_castsi256_ps(_mm256_slli_epi32(
2383                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2384                     reinterpret_cast<const __m128i*>(ip + (56)))),
2385                 16)),
2386             vop56);
2387         // skip unnecessary prefetch of (&ip_next_T0[56])
2388       }
2389       if (!normalize_by_lengths || length == 0) {
2390         _mm256_storeu_ps(&op[0], vop0);
2391         _mm256_storeu_ps(&op[8], vop8);
2392         _mm256_storeu_ps(&op[16], vop16);
2393         _mm256_storeu_ps(&op[24], vop24);
2394         _mm256_storeu_ps(&op[32], vop32);
2395         _mm256_storeu_ps(&op[40], vop40);
2396         _mm256_storeu_ps(&op[48], vop48);
2397         _mm256_storeu_ps(&op[56], vop56);
2398       } else {
2399         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2400         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2401         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2402         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2403         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2404         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2405         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2406         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2407         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2408       }
2409     }
2410   } else if (block_size == 32) {
2411     // unrolling 4 times
2412     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2413       float* op = &out[rangeIndex * block_size];
2414       __m256 vop0 = _mm256_setzero_ps();
2415       __m256 vop8 = _mm256_setzero_ps();
2416       __m256 vop16 = _mm256_setzero_ps();
2417       __m256 vop24 = _mm256_setzero_ps();
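      // Sanity check: dataInd must equal this bag's start relative to
      // offsets[0]; inconsistent offsets make the kernel return false.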
2418       if (dataInd != offsets[rangeIndex] - offsets[0]) {
2419         return false;
2420       }
2421       int64_t end_offset = offsets[rangeIndex + 1];
2422       int64_t length = end_offset - offsets[rangeIndex];
2423       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2424            ++dataInd) {
2425         const int idx = indices[dataInd];
2426         if (idx < 0 || idx >= data_size) {
2427           return false;
2428         }
2429         float wgt = 1.f;
2430         if (weights) {
2431           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2432         }
2433         __m256 vwgt = _mm256_set1_ps(wgt);
2434         const at::BFloat16* ip = &input[idx * fused_block_size];
2435         const int next_T0 = (dataInd < index_size - prefdist_T0)
2436             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2437             ? (dataInd + prefdist_T0)
2438             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2439             : dataInd;
2440         const int idx_pref_T0 = indices[next_T0];
2441         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2442           return false;
2443         }
2444         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2445         vop0 = _mm256_fmadd_ps(
2446             vwgt,
2447             _mm256_castsi256_ps(_mm256_slli_epi32(
2448                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2449                     reinterpret_cast<const __m128i*>(ip + (0)))),
2450                 16)),
2451             vop0);
2452         _mm_prefetch(
2453             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2454         vop8 = _mm256_fmadd_ps(
2455             vwgt,
2456             _mm256_castsi256_ps(_mm256_slli_epi32(
2457                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2458                     reinterpret_cast<const __m128i*>(ip + (8)))),
2459                 16)),
2460             vop8);
2461         // skip unnecessary prefetch of (&ip_next_T0[8])
2462         vop16 = _mm256_fmadd_ps(
2463             vwgt,
2464             _mm256_castsi256_ps(_mm256_slli_epi32(
2465                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2466                     reinterpret_cast<const __m128i*>(ip + (16)))),
2467                 16)),
2468             vop16);
2469         // skip unnecessary prefetch of (&ip_next_T0[16])
2470         vop24 = _mm256_fmadd_ps(
2471             vwgt,
2472             _mm256_castsi256_ps(_mm256_slli_epi32(
2473                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2474                     reinterpret_cast<const __m128i*>(ip + (24)))),
2475                 16)),
2476             vop24);
2477         // skip unnecessary prefetch of (&ip_next_T0[24])
2478       }
2479       if (!normalize_by_lengths || length == 0) {
2480         _mm256_storeu_ps(&op[0], vop0);
2481         _mm256_storeu_ps(&op[8], vop8);
2482         _mm256_storeu_ps(&op[16], vop16);
2483         _mm256_storeu_ps(&op[24], vop24);
2484       } else {
2485         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2486         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2487         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2488         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2489         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2490       }
2491     }
2492   } else if (block_size == 16) {
2493     // unrolling 2 times
2494     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2495       float* op = &out[rangeIndex * block_size];
2496       __m256 vop0 = _mm256_setzero_ps();
2497       __m256 vop8 = _mm256_setzero_ps();
2498       if (dataInd != offsets[rangeIndex] - offsets[0]) {
2499         return false;
2500       }
2501       int64_t end_offset = offsets[rangeIndex + 1];
2502       int64_t length = end_offset - offsets[rangeIndex];
2503       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2504            ++dataInd) {
2505         const int idx = indices[dataInd];
2506         if (idx < 0 || idx >= data_size) {
2507           return false;
2508         }
2509         float wgt = 1.f;
2510         if (weights) {
2511           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2512         }
2513         __m256 vwgt = _mm256_set1_ps(wgt);
2514         const at::BFloat16* ip = &input[idx * fused_block_size];
2515         const int next_T0 = (dataInd < index_size - prefdist_T0)
2516             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2517             ? (dataInd + prefdist_T0)
2518             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2519             : dataInd;
2520         const int idx_pref_T0 = indices[next_T0];
2521         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2522           return false;
2523         }
2524         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2525         vop0 = _mm256_fmadd_ps(
2526             vwgt,
2527             _mm256_castsi256_ps(_mm256_slli_epi32(
2528                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2529                     reinterpret_cast<const __m128i*>(ip + (0)))),
2530                 16)),
2531             vop0);
2532         _mm_prefetch(
2533             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2534         vop8 = _mm256_fmadd_ps(
2535             vwgt,
2536             _mm256_castsi256_ps(_mm256_slli_epi32(
2537                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2538                     reinterpret_cast<const __m128i*>(ip + (8)))),
2539                 16)),
2540             vop8);
2541         // skip unnecessary prefetch of (&ip_next_T0[8])
2542       }
2543       if (!normalize_by_lengths || length == 0) {
2544         _mm256_storeu_ps(&op[0], vop0);
2545         _mm256_storeu_ps(&op[8], vop8);
2546       } else {
2547         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2548         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2549         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2550       }
2551     }
2552   } else {
2553     // generic code
2554     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
2555     alignas(64) at::BFloat16 vtmp1[8] = {0};
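    // vtmp1 is a scratch buffer used by the scalar remainder loop below to
    // widen one trailing bfloat16 element at a time to float.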
2556     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2557       float* op = &out[rangeIndex * block_size];
2558       int64_t j = 0;
2559       for (; j + 8 <= block_size; j += 8) {
2560         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2561       }
2562       for (; j < block_size; j++) {
2563         op[j] = 0.0f;
2564       }
2565       if (dataInd != offsets[rangeIndex] - offsets[0]) {
2566         return false;
2567       }
2568       int64_t end_offset = offsets[rangeIndex + 1];
2569       int64_t length = end_offset - offsets[rangeIndex];
2570       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2571            ++dataInd) {
2572         const int idx = indices[dataInd];
2573         if (idx < 0 || idx >= data_size) {
2574           return false;
2575         }
2576         float wgt = 1.f;
2577         if (weights) {
2578           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2579         }
2580         __m256 vwgt = _mm256_set1_ps(wgt);
2581         const at::BFloat16* ip = &input[idx * fused_block_size];
2582         const int next_T0 = (dataInd < index_size - prefdist_T0)
2583             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2584             ? (dataInd + prefdist_T0)
2585             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2586             : dataInd;
2587         const int idx_pref_T0 = indices[next_T0];
2588         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2589           return false;
2590         }
2591         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2592         j = 0;
2593         for (; j + 8 <= block_size; j += 8) {
2594           _mm256_storeu_ps(
2595               &op[j],
2596               _mm256_fmadd_ps(
2597                   vwgt,
2598                   _mm256_castsi256_ps(_mm256_slli_epi32(
2599                       _mm256_cvtepu16_epi32(_mm_loadu_si128(
2600                           reinterpret_cast<const __m128i*>(&ip[j]))),
2601                       16)),
2602                   _mm256_loadu_ps(&op[j])));
2603           _mm_prefetch(
2604               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
2605         }
2606         for (; j < block_size; j++) {
2607           vtmp1[0] = ip[j];
2608           __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
2609               _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
2610               16));
2611           op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
2612         }
2613       }
2614       if (normalize_by_lengths && length) {
2615         float len_inv = 1.0f / length;
2616         __m256 vlen_inv = _mm256_set1_ps(len_inv);
2617         j = 0;
2618         for (; j + 8 <= block_size; j += 8) {
2619           _mm256_storeu_ps(
2620               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2621         }
2622         for (; j < block_size; j++) {
2623           op[j] = len_inv * op[j];
2624         }
2625       }
2626     }
2627   }
2628   return dataInd == index_size;
2629 }
2630 bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__avx2_fma(
2631     const int64_t block_size,
2632     const int64_t output_size,
2633     const int64_t index_size,
2634     const int64_t data_size,
2635     const at::BFloat16* input,
2636     const int* indices,
2637     const int* offsets,
2638     const float* weights,
2639     const float* scale_bias,
2640     bool normalize_by_lengths,
2641     float* out) {
2642   return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<false>(
2643       block_size,
2644       output_size,
2645       index_size,
2646       data_size,
2647       input,
2648       indices,
2649       offsets,
2650       weights,
2651       scale_bias,
2652       normalize_by_lengths,
2653       out);
2654 }
2655 bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__avx2_fma(
2656     const int64_t block_size,
2657     const int64_t output_size,
2658     const int64_t index_size,
2659     const int64_t data_size,
2660     const at::BFloat16* input,
2661     const int* indices,
2662     const int* offsets,
2663     const float* weights,
2664     const float* scale_bias,
2665     bool normalize_by_lengths,
2666     float* out) {
2667   return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<true>(
2668       block_size,
2669       output_size,
2670       index_size,
2671       data_size,
2672       input,
2673       indices,
2674       offsets,
2675       weights,
2676       scale_bias,
2677       normalize_by_lengths,
2678       out);
2679 }
2680 
2681 template <bool IS_WEIGHT_POSITIONAL>
2682 static bool EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma(
2683     const int64_t block_size,
2684     const int64_t output_size,
2685     const int64_t index_size,
2686     const int64_t data_size,
2687     const at::BFloat16* input,
2688     const int64_t* indices,
2689     const int64_t* offsets,
2690     const float* weights,
2691     const float* scale_bias,
2692     bool normalize_by_lengths,
2693     float* out) {
2694   const int64_t prefdist_T0 = 16;
2695   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2696   const int64_t fused_block_size = block_size + 0;
2697   int64_t dataInd = 0;
2698   if (block_size == 128) {
2699     // unrolling 16 times
2700     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2701       float* op = &out[rangeIndex * block_size];
2702       __m256 vop0 = _mm256_setzero_ps();
2703       __m256 vop8 = _mm256_setzero_ps();
2704       __m256 vop16 = _mm256_setzero_ps();
2705       __m256 vop24 = _mm256_setzero_ps();
2706       __m256 vop32 = _mm256_setzero_ps();
2707       __m256 vop40 = _mm256_setzero_ps();
2708       __m256 vop48 = _mm256_setzero_ps();
2709       __m256 vop56 = _mm256_setzero_ps();
2710       __m256 vop64 = _mm256_setzero_ps();
2711       __m256 vop72 = _mm256_setzero_ps();
2712       __m256 vop80 = _mm256_setzero_ps();
2713       __m256 vop88 = _mm256_setzero_ps();
2714       __m256 vop96 = _mm256_setzero_ps();
2715       __m256 vop104 = _mm256_setzero_ps();
2716       __m256 vop112 = _mm256_setzero_ps();
2717       __m256 vop120 = _mm256_setzero_ps();
2718       if (dataInd != offsets[rangeIndex] - offsets[0]) {
2719         return false;
2720       }
2721       int64_t end_offset = offsets[rangeIndex + 1];
2722       int64_t length = end_offset - offsets[rangeIndex];
2723       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2724            ++dataInd) {
2725         const int64_t idx = indices[dataInd];
2726         if (idx < 0 || idx >= data_size) {
2727           return false;
2728         }
2729         float wgt = 1.f;
2730         if (weights) {
2731           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2732         }
2733         __m256 vwgt = _mm256_set1_ps(wgt);
2734         const at::BFloat16* ip = &input[idx * fused_block_size];
2735         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2736             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2737             ? (dataInd + prefdist_T0)
2738             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2739             : dataInd;
2740         const int64_t idx_pref_T0 = indices[next_T0];
2741         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2742           return false;
2743         }
2744         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2745         vop0 = _mm256_fmadd_ps(
2746             vwgt,
2747             _mm256_castsi256_ps(_mm256_slli_epi32(
2748                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2749                     reinterpret_cast<const __m128i*>(ip + (0)))),
2750                 16)),
2751             vop0);
2752         _mm_prefetch(
2753             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2754         vop8 = _mm256_fmadd_ps(
2755             vwgt,
2756             _mm256_castsi256_ps(_mm256_slli_epi32(
2757                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2758                     reinterpret_cast<const __m128i*>(ip + (8)))),
2759                 16)),
2760             vop8);
2761         // skip unnecessary prefetch of (&ip_next_T0[8])
2762         vop16 = _mm256_fmadd_ps(
2763             vwgt,
2764             _mm256_castsi256_ps(_mm256_slli_epi32(
2765                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2766                     reinterpret_cast<const __m128i*>(ip + (16)))),
2767                 16)),
2768             vop16);
2769         // skip unnecessary prefetch of (&ip_next_T0[16])
2770         vop24 = _mm256_fmadd_ps(
2771             vwgt,
2772             _mm256_castsi256_ps(_mm256_slli_epi32(
2773                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2774                     reinterpret_cast<const __m128i*>(ip + (24)))),
2775                 16)),
2776             vop24);
2777         // skip unnecessary prefetch of (&ip_next_T0[24])
2778         vop32 = _mm256_fmadd_ps(
2779             vwgt,
2780             _mm256_castsi256_ps(_mm256_slli_epi32(
2781                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2782                     reinterpret_cast<const __m128i*>(ip + (32)))),
2783                 16)),
2784             vop32);
2785         _mm_prefetch(
2786             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2787         vop40 = _mm256_fmadd_ps(
2788             vwgt,
2789             _mm256_castsi256_ps(_mm256_slli_epi32(
2790                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2791                     reinterpret_cast<const __m128i*>(ip + (40)))),
2792                 16)),
2793             vop40);
2794         // skip unnecessary prefetch of (&ip_next_T0[40])
2795         vop48 = _mm256_fmadd_ps(
2796             vwgt,
2797             _mm256_castsi256_ps(_mm256_slli_epi32(
2798                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2799                     reinterpret_cast<const __m128i*>(ip + (48)))),
2800                 16)),
2801             vop48);
2802         // skip unnecessary prefetch of (&ip_next_T0[48])
2803         vop56 = _mm256_fmadd_ps(
2804             vwgt,
2805             _mm256_castsi256_ps(_mm256_slli_epi32(
2806                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2807                     reinterpret_cast<const __m128i*>(ip + (56)))),
2808                 16)),
2809             vop56);
2810         // skip unnecessary prefetch of (&ip_next_T0[56])
2811         vop64 = _mm256_fmadd_ps(
2812             vwgt,
2813             _mm256_castsi256_ps(_mm256_slli_epi32(
2814                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2815                     reinterpret_cast<const __m128i*>(ip + (64)))),
2816                 16)),
2817             vop64);
2818         _mm_prefetch(
2819             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2820         vop72 = _mm256_fmadd_ps(
2821             vwgt,
2822             _mm256_castsi256_ps(_mm256_slli_epi32(
2823                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2824                     reinterpret_cast<const __m128i*>(ip + (72)))),
2825                 16)),
2826             vop72);
2827         // skip unnecessary prefetch of (&ip_next_T0[72])
2828         vop80 = _mm256_fmadd_ps(
2829             vwgt,
2830             _mm256_castsi256_ps(_mm256_slli_epi32(
2831                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2832                     reinterpret_cast<const __m128i*>(ip + (80)))),
2833                 16)),
2834             vop80);
2835         // skip unnecessary prefetch of (&ip_next_T0[80])
2836         vop88 = _mm256_fmadd_ps(
2837             vwgt,
2838             _mm256_castsi256_ps(_mm256_slli_epi32(
2839                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2840                     reinterpret_cast<const __m128i*>(ip + (88)))),
2841                 16)),
2842             vop88);
2843         // skip unnecessary prefetch of (&ip_next_T0[88])
2844         vop96 = _mm256_fmadd_ps(
2845             vwgt,
2846             _mm256_castsi256_ps(_mm256_slli_epi32(
2847                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2848                     reinterpret_cast<const __m128i*>(ip + (96)))),
2849                 16)),
2850             vop96);
2851         _mm_prefetch(
2852             reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
2853         vop104 = _mm256_fmadd_ps(
2854             vwgt,
2855             _mm256_castsi256_ps(_mm256_slli_epi32(
2856                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2857                     reinterpret_cast<const __m128i*>(ip + (104)))),
2858                 16)),
2859             vop104);
2860         // skip unnecessary prefetch of (&ip_next_T0[104])
2861         vop112 = _mm256_fmadd_ps(
2862             vwgt,
2863             _mm256_castsi256_ps(_mm256_slli_epi32(
2864                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2865                     reinterpret_cast<const __m128i*>(ip + (112)))),
2866                 16)),
2867             vop112);
2868         // skip unnecessary prefetch of (&ip_next_T0[112])
2869         vop120 = _mm256_fmadd_ps(
2870             vwgt,
2871             _mm256_castsi256_ps(_mm256_slli_epi32(
2872                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2873                     reinterpret_cast<const __m128i*>(ip + (120)))),
2874                 16)),
2875             vop120);
2876         // skip unnecessary prefetch of (&ip_next_T0[120])
2877       }
2878       if (!normalize_by_lengths || length == 0) {
2879         _mm256_storeu_ps(&op[0], vop0);
2880         _mm256_storeu_ps(&op[8], vop8);
2881         _mm256_storeu_ps(&op[16], vop16);
2882         _mm256_storeu_ps(&op[24], vop24);
2883         _mm256_storeu_ps(&op[32], vop32);
2884         _mm256_storeu_ps(&op[40], vop40);
2885         _mm256_storeu_ps(&op[48], vop48);
2886         _mm256_storeu_ps(&op[56], vop56);
2887         _mm256_storeu_ps(&op[64], vop64);
2888         _mm256_storeu_ps(&op[72], vop72);
2889         _mm256_storeu_ps(&op[80], vop80);
2890         _mm256_storeu_ps(&op[88], vop88);
2891         _mm256_storeu_ps(&op[96], vop96);
2892         _mm256_storeu_ps(&op[104], vop104);
2893         _mm256_storeu_ps(&op[112], vop112);
2894         _mm256_storeu_ps(&op[120], vop120);
2895       } else {
2896         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2897         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2898         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2899         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2900         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2901         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2902         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2903         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2904         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2905         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2906         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2907         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2908         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2909         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2910         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2911         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2912         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2913       }
2914     }
2915   } else if (block_size == 64) {
2916     // unrolling 8 times
2917     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2918       float* op = &out[rangeIndex * block_size];
2919       __m256 vop0 = _mm256_setzero_ps();
2920       __m256 vop8 = _mm256_setzero_ps();
2921       __m256 vop16 = _mm256_setzero_ps();
2922       __m256 vop24 = _mm256_setzero_ps();
2923       __m256 vop32 = _mm256_setzero_ps();
2924       __m256 vop40 = _mm256_setzero_ps();
2925       __m256 vop48 = _mm256_setzero_ps();
2926       __m256 vop56 = _mm256_setzero_ps();
2927       if (dataInd != offsets[rangeIndex] - offsets[0]) {
2928         return false;
2929       }
2930       int64_t end_offset = offsets[rangeIndex + 1];
2931       int64_t length = end_offset - offsets[rangeIndex];
2932       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2933            ++dataInd) {
2934         const int64_t idx = indices[dataInd];
2935         if (idx < 0 || idx >= data_size) {
2936           return false;
2937         }
2938         float wgt = 1.f;
2939         if (weights) {
2940           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2941         }
2942         __m256 vwgt = _mm256_set1_ps(wgt);
2943         const at::BFloat16* ip = &input[idx * fused_block_size];
2944         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2945             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2946             ? (dataInd + prefdist_T0)
2947             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2948             : dataInd;
2949         const int64_t idx_pref_T0 = indices[next_T0];
2950         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2951           return false;
2952         }
2953         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2954         vop0 = _mm256_fmadd_ps(
2955             vwgt,
2956             _mm256_castsi256_ps(_mm256_slli_epi32(
2957                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2958                     reinterpret_cast<const __m128i*>(ip + (0)))),
2959                 16)),
2960             vop0);
2961         _mm_prefetch(
2962             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2963         vop8 = _mm256_fmadd_ps(
2964             vwgt,
2965             _mm256_castsi256_ps(_mm256_slli_epi32(
2966                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2967                     reinterpret_cast<const __m128i*>(ip + (8)))),
2968                 16)),
2969             vop8);
2970         // skip unnecessary prefetch of (&ip_next_T0[8])
2971         vop16 = _mm256_fmadd_ps(
2972             vwgt,
2973             _mm256_castsi256_ps(_mm256_slli_epi32(
2974                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2975                     reinterpret_cast<const __m128i*>(ip + (16)))),
2976                 16)),
2977             vop16);
2978         // skip unnecessary prefetch of (&ip_next_T0[16])
2979         vop24 = _mm256_fmadd_ps(
2980             vwgt,
2981             _mm256_castsi256_ps(_mm256_slli_epi32(
2982                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2983                     reinterpret_cast<const __m128i*>(ip + (24)))),
2984                 16)),
2985             vop24);
2986         // skip unnecessary prefetch of (&ip_next_T0[24])
2987         vop32 = _mm256_fmadd_ps(
2988             vwgt,
2989             _mm256_castsi256_ps(_mm256_slli_epi32(
2990                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
2991                     reinterpret_cast<const __m128i*>(ip + (32)))),
2992                 16)),
2993             vop32);
2994         _mm_prefetch(
2995             reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2996         vop40 = _mm256_fmadd_ps(
2997             vwgt,
2998             _mm256_castsi256_ps(_mm256_slli_epi32(
2999                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3000                     reinterpret_cast<const __m128i*>(ip + (40)))),
3001                 16)),
3002             vop40);
3003         // skip unnecessary prefetch of (&ip_next_T0[40])
3004         vop48 = _mm256_fmadd_ps(
3005             vwgt,
3006             _mm256_castsi256_ps(_mm256_slli_epi32(
3007                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3008                     reinterpret_cast<const __m128i*>(ip + (48)))),
3009                 16)),
3010             vop48);
3011         // skip unnecessary prefetch of (&ip_next_T0[48])
3012         vop56 = _mm256_fmadd_ps(
3013             vwgt,
3014             _mm256_castsi256_ps(_mm256_slli_epi32(
3015                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3016                     reinterpret_cast<const __m128i*>(ip + (56)))),
3017                 16)),
3018             vop56);
3019         // skip unnecessary prefetch of (&ip_next_T0[56])
3020       }
3021       if (!normalize_by_lengths || length == 0) {
3022         _mm256_storeu_ps(&op[0], vop0);
3023         _mm256_storeu_ps(&op[8], vop8);
3024         _mm256_storeu_ps(&op[16], vop16);
3025         _mm256_storeu_ps(&op[24], vop24);
3026         _mm256_storeu_ps(&op[32], vop32);
3027         _mm256_storeu_ps(&op[40], vop40);
3028         _mm256_storeu_ps(&op[48], vop48);
3029         _mm256_storeu_ps(&op[56], vop56);
3030       } else {
3031         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3032         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3033         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3034         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3035         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3036         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
3037         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
3038         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
3039         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
3040       }
3041     }
3042   } else if (block_size == 32) {
3043     // unrolling 4 times
3044     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3045       float* op = &out[rangeIndex * block_size];
3046       __m256 vop0 = _mm256_setzero_ps();
3047       __m256 vop8 = _mm256_setzero_ps();
3048       __m256 vop16 = _mm256_setzero_ps();
3049       __m256 vop24 = _mm256_setzero_ps();
3050       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3051         return false;
3052       }
3053       int64_t end_offset = offsets[rangeIndex + 1];
3054       int64_t length = end_offset - offsets[rangeIndex];
3055       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3056            ++dataInd) {
3057         const int64_t idx = indices[dataInd];
3058         if (idx < 0 || idx >= data_size) {
3059           return false;
3060         }
3061         float wgt = 1.f;
3062         if (weights) {
3063           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3064         }
3065         __m256 vwgt = _mm256_set1_ps(wgt);
3066         const at::BFloat16* ip = &input[idx * fused_block_size];
3067         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3068             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3069             ? (dataInd + prefdist_T0)
3070             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3071             : dataInd;
3072         const int64_t idx_pref_T0 = indices[next_T0];
3073         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3074           return false;
3075         }
3076         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3077         vop0 = _mm256_fmadd_ps(
3078             vwgt,
3079             _mm256_castsi256_ps(_mm256_slli_epi32(
3080                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3081                     reinterpret_cast<const __m128i*>(ip + (0)))),
3082                 16)),
3083             vop0);
3084         _mm_prefetch(
3085             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3086         vop8 = _mm256_fmadd_ps(
3087             vwgt,
3088             _mm256_castsi256_ps(_mm256_slli_epi32(
3089                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3090                     reinterpret_cast<const __m128i*>(ip + (8)))),
3091                 16)),
3092             vop8);
3093         // skip unnecessary prefetch of (&ip_next_T0[8])
3094         vop16 = _mm256_fmadd_ps(
3095             vwgt,
3096             _mm256_castsi256_ps(_mm256_slli_epi32(
3097                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3098                     reinterpret_cast<const __m128i*>(ip + (16)))),
3099                 16)),
3100             vop16);
3101         // skip unnecessary prefetch of (&ip_next_T0[16])
3102         vop24 = _mm256_fmadd_ps(
3103             vwgt,
3104             _mm256_castsi256_ps(_mm256_slli_epi32(
3105                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3106                     reinterpret_cast<const __m128i*>(ip + (24)))),
3107                 16)),
3108             vop24);
3109         // skip unnecessary prefetch of (&ip_next_T0[24])
3110       }
3111       if (!normalize_by_lengths || length == 0) {
3112         _mm256_storeu_ps(&op[0], vop0);
3113         _mm256_storeu_ps(&op[8], vop8);
3114         _mm256_storeu_ps(&op[16], vop16);
3115         _mm256_storeu_ps(&op[24], vop24);
3116       } else {
3117         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3118         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3119         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3120         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3121         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3122       }
3123     }
3124   } else if (block_size == 16) {
3125     // unrolling 2 times
3126     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3127       float* op = &out[rangeIndex * block_size];
3128       __m256 vop0 = _mm256_setzero_ps();
3129       __m256 vop8 = _mm256_setzero_ps();
3130       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3131         return false;
3132       }
3133       int64_t end_offset = offsets[rangeIndex + 1];
3134       int64_t length = end_offset - offsets[rangeIndex];
3135       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3136            ++dataInd) {
3137         const int64_t idx = indices[dataInd];
3138         if (idx < 0 || idx >= data_size) {
3139           return false;
3140         }
3141         float wgt = 1.f;
3142         if (weights) {
3143           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3144         }
3145         __m256 vwgt = _mm256_set1_ps(wgt);
3146         const at::BFloat16* ip = &input[idx * fused_block_size];
3147         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3148             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3149             ? (dataInd + prefdist_T0)
3150             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3151             : dataInd;
3152         const int64_t idx_pref_T0 = indices[next_T0];
3153         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3154           return false;
3155         }
3156         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3157         vop0 = _mm256_fmadd_ps(
3158             vwgt,
3159             _mm256_castsi256_ps(_mm256_slli_epi32(
3160                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3161                     reinterpret_cast<const __m128i*>(ip + (0)))),
3162                 16)),
3163             vop0);
3164         _mm_prefetch(
3165             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3166         vop8 = _mm256_fmadd_ps(
3167             vwgt,
3168             _mm256_castsi256_ps(_mm256_slli_epi32(
3169                 _mm256_cvtepu16_epi32(_mm_loadu_si128(
3170                     reinterpret_cast<const __m128i*>(ip + (8)))),
3171                 16)),
3172             vop8);
3173         // skip unnecessary prefetch of (&ip_next_T0[8])
3174       }
3175       if (!normalize_by_lengths || length == 0) {
3176         _mm256_storeu_ps(&op[0], vop0);
3177         _mm256_storeu_ps(&op[8], vop8);
3178       } else {
3179         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3180         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3181         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3182       }
3183     }
3184   } else {
3185     // generic code
3186     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
3187     alignas(64) at::BFloat16 vtmp1[8] = {0};
3188     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3189       float* op = &out[rangeIndex * block_size];
3190       int64_t j = 0;
3191       for (; j + 8 <= block_size; j += 8) {
3192         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
3193       }
3194       for (; j < block_size; j++) {
3195         op[j] = 0.0f;
3196       }
3197       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3198         return false;
3199       }
3200       int64_t end_offset = offsets[rangeIndex + 1];
3201       int64_t length = end_offset - offsets[rangeIndex];
3202       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3203            ++dataInd) {
3204         const int64_t idx = indices[dataInd];
3205         if (idx < 0 || idx >= data_size) {
3206           return false;
3207         }
3208         float wgt = 1.f;
3209         if (weights) {
3210           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3211         }
3212         __m256 vwgt = _mm256_set1_ps(wgt);
3213         const at::BFloat16* ip = &input[idx * fused_block_size];
3214         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3215             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3216             ? (dataInd + prefdist_T0)
3217             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3218             : dataInd;
3219         const int64_t idx_pref_T0 = indices[next_T0];
3220         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3221           return false;
3222         }
3223         const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3224         j = 0;
3225         for (; j + 8 <= block_size; j += 8) {
3226           _mm256_storeu_ps(
3227               &op[j],
3228               _mm256_fmadd_ps(
3229                   vwgt,
3230                   _mm256_castsi256_ps(_mm256_slli_epi32(
3231                       _mm256_cvtepu16_epi32(_mm_loadu_si128(
3232                           reinterpret_cast<const __m128i*>(&ip[j]))),
3233                       16)),
3234                   _mm256_loadu_ps(&op[j])));
3235           _mm_prefetch(
3236               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
3237         }
3238         for (; j < block_size; j++) {
3239           vtmp1[0] = ip[j];
3240           __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
3241               _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
3242               16));
3243           op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
3244         }
3245       }
3246       if (normalize_by_lengths && length) {
3247         float len_inv = 1.0f / length;
3248         __m256 vlen_inv = _mm256_set1_ps(len_inv);
3249         j = 0;
3250         for (; j + 8 <= block_size; j += 8) {
3251           _mm256_storeu_ps(
3252               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
3253         }
3254         for (; j < block_size; j++) {
3255           op[j] = len_inv * op[j];
3256         }
3257       }
3258     }
3259   }
3260   return dataInd == index_size;
3261 }
3262 bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__avx2_fma(
3263     const int64_t block_size,
3264     const int64_t output_size,
3265     const int64_t index_size,
3266     const int64_t data_size,
3267     const at::BFloat16* input,
3268     const int64_t* indices,
3269     const int64_t* offsets,
3270     const float* weights,
3271     const float* scale_bias,
3272     bool normalize_by_lengths,
3273     float* out) {
3274   return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<false>(
3275       block_size,
3276       output_size,
3277       index_size,
3278       data_size,
3279       input,
3280       indices,
3281       offsets,
3282       weights,
3283       scale_bias,
3284       normalize_by_lengths,
3285       out);
3286 }
3287 bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__avx2_fma(
3288     const int64_t block_size,
3289     const int64_t output_size,
3290     const int64_t index_size,
3291     const int64_t data_size,
3292     const at::BFloat16* input,
3293     const int64_t* indices,
3294     const int64_t* offsets,
3295     const float* weights,
3296     const float* scale_bias,
3297     bool normalize_by_lengths,
3298     float* out) {
3299   return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<true>(
3300       block_size,
3301       output_size,
3302       index_size,
3303       data_size,
3304       input,
3305       indices,
3306       offsets,
3307       weights,
3308       scale_bias,
3309       normalize_by_lengths,
3310       out);
3311 }
3312 
3313 template <bool IS_WEIGHT_POSITIONAL>
3314 static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
3315     const int64_t block_size,
3316     const int64_t output_size,
3317     const int64_t index_size,
3318     const int64_t data_size,
3319     const uint8_t* input,
3320     const int* indices,
3321     const int* offsets,
3322     const float* weights,
3323     const float* scale_bias,
3324     bool normalize_by_lengths,
3325     float* out) {
3326   const int prefdist_T0 = 16;
3327   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3328   const int fused_block_size = block_size + 0;
3329   int64_t dataInd = 0;
3330   if (block_size == 128) {
3331     // unrolling 16 times
3332     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3333       float* op = &out[rangeIndex * block_size];
3334       __m256 vop0 = _mm256_setzero_ps();
3335       __m256 vop8 = _mm256_setzero_ps();
3336       __m256 vop16 = _mm256_setzero_ps();
3337       __m256 vop24 = _mm256_setzero_ps();
3338       __m256 vop32 = _mm256_setzero_ps();
3339       __m256 vop40 = _mm256_setzero_ps();
3340       __m256 vop48 = _mm256_setzero_ps();
3341       __m256 vop56 = _mm256_setzero_ps();
3342       __m256 vop64 = _mm256_setzero_ps();
3343       __m256 vop72 = _mm256_setzero_ps();
3344       __m256 vop80 = _mm256_setzero_ps();
3345       __m256 vop88 = _mm256_setzero_ps();
3346       __m256 vop96 = _mm256_setzero_ps();
3347       __m256 vop104 = _mm256_setzero_ps();
3348       __m256 vop112 = _mm256_setzero_ps();
3349       __m256 vop120 = _mm256_setzero_ps();
3350       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3351         return false;
3352       }
3353       int64_t end_offset = offsets[rangeIndex + 1];
3354       int64_t length = end_offset - offsets[rangeIndex];
3355       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3356            ++dataInd) {
3357         const int idx = indices[dataInd];
3358         if (idx < 0 || idx >= data_size) {
3359           return false;
3360         }
3361         float wgt = 1.f;
3362         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3363         float bio;
3364         if (weights) {
3365           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3366         }
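        // Rows are stored as uint8 with a per-row (scale, bias) pair
        // interleaved in scale_bias; folding the weight in gives wgt*scale as
        // the multiplier and wgt*bias as the offset added to the accumulator
        // on every iteration.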
3367         bio = wgt * scale_bias[2 * idx + 1];
3368         wgt = wgt * scale_bias[2 * idx];
3369         __m256 vbio = _mm256_set1_ps(bio);
3370         __m256 vwgt = _mm256_set1_ps(wgt);
3371         const uint8_t* ip = &input[idx * fused_block_size];
3372         const int next_T0 = (dataInd < index_size - prefdist_T0)
3373             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3374             ? (dataInd + prefdist_T0)
3375             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3376             : dataInd;
3377         const int idx_pref_T0 = indices[next_T0];
3378         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3379           return false;
3380         }
3381         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3382         vop0 = _mm256_fmadd_ps(
3383             vwgt,
3384             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3385                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3386             _mm256_add_ps(vop0, vbio));
3387         _mm_prefetch(
3388             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3389         vop8 = _mm256_fmadd_ps(
3390             vwgt,
3391             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3392                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3393             _mm256_add_ps(vop8, vbio));
3394         // skip unnecessary prefetch of (&ip_next_T0[8])
3395         vop16 = _mm256_fmadd_ps(
3396             vwgt,
3397             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3398                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3399             _mm256_add_ps(vop16, vbio));
3400         // skip unnecessary prefetch of (&ip_next_T0[16])
3401         vop24 = _mm256_fmadd_ps(
3402             vwgt,
3403             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3404                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3405             _mm256_add_ps(vop24, vbio));
3406         // skip unnecessary prefetch of (&ip_next_T0[24])
3407         vop32 = _mm256_fmadd_ps(
3408             vwgt,
3409             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3410                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
3411             _mm256_add_ps(vop32, vbio));
3412         // skip unnecessary prefetch of (&ip_next_T0[32])
3413         vop40 = _mm256_fmadd_ps(
3414             vwgt,
3415             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3416                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
3417             _mm256_add_ps(vop40, vbio));
3418         // skip unnecessary prefetch of (&ip_next_T0[40])
3419         vop48 = _mm256_fmadd_ps(
3420             vwgt,
3421             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3422                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
3423             _mm256_add_ps(vop48, vbio));
3424         // skip unnecessary prefetch of (&ip_next_T0[48])
3425         vop56 = _mm256_fmadd_ps(
3426             vwgt,
3427             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3428                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
3429             _mm256_add_ps(vop56, vbio));
3430         // skip unnecessary prefetch of (&ip_next_T0[56])
3431         vop64 = _mm256_fmadd_ps(
3432             vwgt,
3433             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3434                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
3435             _mm256_add_ps(vop64, vbio));
3436         _mm_prefetch(
3437             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
3438         vop72 = _mm256_fmadd_ps(
3439             vwgt,
3440             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3441                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
3442             _mm256_add_ps(vop72, vbio));
3443         // skip unnecessary prefetch of (&ip_next_T0[72])
3444         vop80 = _mm256_fmadd_ps(
3445             vwgt,
3446             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3447                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
3448             _mm256_add_ps(vop80, vbio));
3449         // skip unnecessary prefetch of (&ip_next_T0[80])
3450         vop88 = _mm256_fmadd_ps(
3451             vwgt,
3452             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3453                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
3454             _mm256_add_ps(vop88, vbio));
3455         // skip unnecessary prefetch of (&ip_next_T0[88])
3456         vop96 = _mm256_fmadd_ps(
3457             vwgt,
3458             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3459                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
3460             _mm256_add_ps(vop96, vbio));
3461         // skip unnecessary prefetch of (&ip_next_T0[96])
3462         vop104 = _mm256_fmadd_ps(
3463             vwgt,
3464             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3465                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
3466             _mm256_add_ps(vop104, vbio));
3467         // skip unnecessary prefetch of (&ip_next_T0[104])
3468         vop112 = _mm256_fmadd_ps(
3469             vwgt,
3470             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3471                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
3472             _mm256_add_ps(vop112, vbio));
3473         // skip unnecessary prefetch of (&ip_next_T0[112])
3474         vop120 = _mm256_fmadd_ps(
3475             vwgt,
3476             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3477                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
3478             _mm256_add_ps(vop120, vbio));
3479         // skip unnecessary prefetch of (&ip_next_T0[120])
3480       }
3481       if (!normalize_by_lengths || length == 0) {
3482         _mm256_storeu_ps(&op[0], vop0);
3483         _mm256_storeu_ps(&op[8], vop8);
3484         _mm256_storeu_ps(&op[16], vop16);
3485         _mm256_storeu_ps(&op[24], vop24);
3486         _mm256_storeu_ps(&op[32], vop32);
3487         _mm256_storeu_ps(&op[40], vop40);
3488         _mm256_storeu_ps(&op[48], vop48);
3489         _mm256_storeu_ps(&op[56], vop56);
3490         _mm256_storeu_ps(&op[64], vop64);
3491         _mm256_storeu_ps(&op[72], vop72);
3492         _mm256_storeu_ps(&op[80], vop80);
3493         _mm256_storeu_ps(&op[88], vop88);
3494         _mm256_storeu_ps(&op[96], vop96);
3495         _mm256_storeu_ps(&op[104], vop104);
3496         _mm256_storeu_ps(&op[112], vop112);
3497         _mm256_storeu_ps(&op[120], vop120);
3498       } else {
3499         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3500         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3501         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3502         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3503         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3504         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
3505         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
3506         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
3507         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
3508         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
3509         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
3510         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
3511         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
3512         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
3513         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
3514         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
3515         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
3516       }
3517     }
3518   } else if (block_size == 64) {
3519     // unrolling 8 times
3520     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3521       float* op = &out[rangeIndex * block_size];
3522       __m256 vop0 = _mm256_setzero_ps();
3523       __m256 vop8 = _mm256_setzero_ps();
3524       __m256 vop16 = _mm256_setzero_ps();
3525       __m256 vop24 = _mm256_setzero_ps();
3526       __m256 vop32 = _mm256_setzero_ps();
3527       __m256 vop40 = _mm256_setzero_ps();
3528       __m256 vop48 = _mm256_setzero_ps();
3529       __m256 vop56 = _mm256_setzero_ps();
3530       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3531         return false;
3532       }
3533       int64_t end_offset = offsets[rangeIndex + 1];
3534       int64_t length = end_offset - offsets[rangeIndex];
3535       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3536            ++dataInd) {
3537         const int idx = indices[dataInd];
3538         if (idx < 0 || idx >= data_size) {
3539           return false;
3540         }
3541         float wgt = 1.f;
3542         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3543         float bio;
3544         if (weights) {
3545           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3546         }
3547         bio = wgt * scale_bias[2 * idx + 1];
3548         wgt = wgt * scale_bias[2 * idx];
3549         __m256 vbio = _mm256_set1_ps(bio);
3550         __m256 vwgt = _mm256_set1_ps(wgt);
3551         const uint8_t* ip = &input[idx * fused_block_size];
3552         const int next_T0 = (dataInd < index_size - prefdist_T0)
3553             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3554             ? (dataInd + prefdist_T0)
3555             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3556             : dataInd;
3557         const int idx_pref_T0 = indices[next_T0];
3558         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3559           return false;
3560         }
3561         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3562         vop0 = _mm256_fmadd_ps(
3563             vwgt,
3564             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3565                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3566             _mm256_add_ps(vop0, vbio));
3567         _mm_prefetch(
3568             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3569         vop8 = _mm256_fmadd_ps(
3570             vwgt,
3571             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3572                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3573             _mm256_add_ps(vop8, vbio));
3574         // skip unnecessary prefetch of (&ip_next_T0[8])
3575         vop16 = _mm256_fmadd_ps(
3576             vwgt,
3577             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3578                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3579             _mm256_add_ps(vop16, vbio));
3580         // skip unnecessary prefetch of (&ip_next_T0[16])
3581         vop24 = _mm256_fmadd_ps(
3582             vwgt,
3583             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3584                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3585             _mm256_add_ps(vop24, vbio));
3586         // skip unnecessary prefetch of (&ip_next_T0[24])
3587         vop32 = _mm256_fmadd_ps(
3588             vwgt,
3589             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3590                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
3591             _mm256_add_ps(vop32, vbio));
3592         // skip unnecessary prefetch of (&ip_next_T0[32])
3593         vop40 = _mm256_fmadd_ps(
3594             vwgt,
3595             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3596                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
3597             _mm256_add_ps(vop40, vbio));
3598         // skip unnecessary prefetch of (&ip_next_T0[40])
3599         vop48 = _mm256_fmadd_ps(
3600             vwgt,
3601             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3602                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
3603             _mm256_add_ps(vop48, vbio));
3604         // skip unnecessary prefetch of (&ip_next_T0[48])
3605         vop56 = _mm256_fmadd_ps(
3606             vwgt,
3607             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3608                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
3609             _mm256_add_ps(vop56, vbio));
3610         // skip unnecessary prefetch of (&ip_next_T0[56])
3611       }
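      // Emit the pooled row: a plain sum when normalize_by_lengths is false or
      // the segment is empty, otherwise the mean (sum scaled by 1 / length).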
3612       if (!normalize_by_lengths || length == 0) {
3613         _mm256_storeu_ps(&op[0], vop0);
3614         _mm256_storeu_ps(&op[8], vop8);
3615         _mm256_storeu_ps(&op[16], vop16);
3616         _mm256_storeu_ps(&op[24], vop24);
3617         _mm256_storeu_ps(&op[32], vop32);
3618         _mm256_storeu_ps(&op[40], vop40);
3619         _mm256_storeu_ps(&op[48], vop48);
3620         _mm256_storeu_ps(&op[56], vop56);
3621       } else {
3622         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3623         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3624         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3625         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3626         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3627         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
3628         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
3629         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
3630         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
3631       }
3632     }
3633   } else if (block_size == 32) {
3634     // unrolling 4 times
3635     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3636       float* op = &out[rangeIndex * block_size];
3637       __m256 vop0 = _mm256_setzero_ps();
3638       __m256 vop8 = _mm256_setzero_ps();
3639       __m256 vop16 = _mm256_setzero_ps();
3640       __m256 vop24 = _mm256_setzero_ps();
3641       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3642         return false;
3643       }
3644       int64_t end_offset = offsets[rangeIndex + 1];
3645       int64_t length = end_offset - offsets[rangeIndex];
3646       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3647            ++dataInd) {
3648         const int idx = indices[dataInd];
3649         if (idx < 0 || idx >= data_size) {
3650           return false;
3651         }
3652         float wgt = 1.f;
3653         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3654         float bio;
3655         if (weights) {
3656           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3657         }
3658         bio = wgt * scale_bias[2 * idx + 1];
3659         wgt = wgt * scale_bias[2 * idx];
3660         __m256 vbio = _mm256_set1_ps(bio);
3661         __m256 vwgt = _mm256_set1_ps(wgt);
3662         const uint8_t* ip = &input[idx * fused_block_size];
3663         const int next_T0 = (dataInd < index_size - prefdist_T0)
3664             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3665             ? (dataInd + prefdist_T0)
3666             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3667             : dataInd;
3668         const int idx_pref_T0 = indices[next_T0];
3669         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3670           return false;
3671         }
3672         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3673         vop0 = _mm256_fmadd_ps(
3674             vwgt,
3675             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3676                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3677             _mm256_add_ps(vop0, vbio));
3678         _mm_prefetch(
3679             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3680         vop8 = _mm256_fmadd_ps(
3681             vwgt,
3682             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3683                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3684             _mm256_add_ps(vop8, vbio));
3685         // skip unnecessary prefetch of (&ip_next_T0[8])
3686         vop16 = _mm256_fmadd_ps(
3687             vwgt,
3688             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3689                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3690             _mm256_add_ps(vop16, vbio));
3691         // skip unnecessary prefetch of (&ip_next_T0[16])
3692         vop24 = _mm256_fmadd_ps(
3693             vwgt,
3694             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3695                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3696             _mm256_add_ps(vop24, vbio));
3697         // skip unnecessary prefetch of (&ip_next_T0[24])
3698       }
3699       if (!normalize_by_lengths || length == 0) {
3700         _mm256_storeu_ps(&op[0], vop0);
3701         _mm256_storeu_ps(&op[8], vop8);
3702         _mm256_storeu_ps(&op[16], vop16);
3703         _mm256_storeu_ps(&op[24], vop24);
3704       } else {
3705         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3706         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3707         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3708         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3709         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3710       }
3711     }
3712   } else if (block_size == 16) {
3713     // unrolling 2 times
3714     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3715       float* op = &out[rangeIndex * block_size];
3716       __m256 vop0 = _mm256_setzero_ps();
3717       __m256 vop8 = _mm256_setzero_ps();
3718       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3719         return false;
3720       }
3721       int64_t end_offset = offsets[rangeIndex + 1];
3722       int64_t length = end_offset - offsets[rangeIndex];
3723       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3724            ++dataInd) {
3725         const int idx = indices[dataInd];
3726         if (idx < 0 || idx >= data_size) {
3727           return false;
3728         }
3729         float wgt = 1.f;
3730         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3731         float bio;
3732         if (weights) {
3733           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3734         }
3735         bio = wgt * scale_bias[2 * idx + 1];
3736         wgt = wgt * scale_bias[2 * idx];
3737         __m256 vbio = _mm256_set1_ps(bio);
3738         __m256 vwgt = _mm256_set1_ps(wgt);
3739         const uint8_t* ip = &input[idx * fused_block_size];
3740         const int next_T0 = (dataInd < index_size - prefdist_T0)
3741             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3742             ? (dataInd + prefdist_T0)
3743             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3744             : dataInd;
3745         const int idx_pref_T0 = indices[next_T0];
3746         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3747           return false;
3748         }
3749         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3750         vop0 = _mm256_fmadd_ps(
3751             vwgt,
3752             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3753                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3754             _mm256_add_ps(vop0, vbio));
3755         _mm_prefetch(
3756             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3757         vop8 = _mm256_fmadd_ps(
3758             vwgt,
3759             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3760                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3761             _mm256_add_ps(vop8, vbio));
3762         // skip unnecessary prefetch of (&ip_next_T0[8])
3763       }
3764       if (!normalize_by_lengths || length == 0) {
3765         _mm256_storeu_ps(&op[0], vop0);
3766         _mm256_storeu_ps(&op[8], vop8);
3767       } else {
3768         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3769         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3770         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3771       }
3772     }
3773   } else {
3774     // generic code
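    // Fallback for arbitrary block_size: zero the output row, accumulate 8
    // floats at a time with AVX2, and finish any remainder with scalar std::fma.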
3775     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
3776     for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3777       float* op = &out[rangeIndex * block_size];
3778       int64_t j = 0;
3779       for (; j + 8 <= block_size; j += 8) {
3780         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
3781       }
3782       for (; j < block_size; j++) {
3783         op[j] = 0.0f;
3784       }
3785       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3786         return false;
3787       }
3788       int64_t end_offset = offsets[rangeIndex + 1];
3789       int64_t length = end_offset - offsets[rangeIndex];
3790       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3791            ++dataInd) {
3792         const int idx = indices[dataInd];
3793         if (idx < 0 || idx >= data_size) {
3794           return false;
3795         }
3796         float wgt = 1.f;
3797         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3798         float bio;
3799         if (weights) {
3800           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3801         }
3802         bio = wgt * scale_bias[2 * idx + 1];
3803         wgt = wgt * scale_bias[2 * idx];
3804         __m256 vbio = _mm256_set1_ps(bio);
3805         __m256 vwgt = _mm256_set1_ps(wgt);
3806         const uint8_t* ip = &input[idx * fused_block_size];
3807         const int next_T0 = (dataInd < index_size - prefdist_T0)
3808             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3809             ? (dataInd + prefdist_T0)
3810             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3811             : dataInd;
3812         const int idx_pref_T0 = indices[next_T0];
3813         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3814           return false;
3815         }
3816         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3817         j = 0;
3818         for (; j + 8 <= block_size; j += 8) {
3819           _mm256_storeu_ps(
3820               &op[j],
3821               _mm256_fmadd_ps(
3822                   vwgt,
3823                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
3824                       reinterpret_cast<const __m128i*>(&ip[j])))),
3825                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
3826           _mm_prefetch(
3827               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
3828         }
3829         for (; j < block_size; j++) {
3830           op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);
3831         }
3832       }
3833       if (normalize_by_lengths && length) {
3834         float len_inv = 1.0f / length;
3835         __m256 vlen_inv = _mm256_set1_ps(len_inv);
3836         j = 0;
3837         for (; j + 8 <= block_size; j += 8) {
3838           _mm256_storeu_ps(
3839               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
3840         }
3841         for (; j < block_size; j++) {
3842           op[j] = len_inv * op[j];
3843         }
3844       }
3845     }
3846   }
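  // Success only if every index was consumed, i.e. offsets and index_size agree.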
3847   return dataInd == index_size;
3848 }
3849 bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
3850     const int64_t block_size,
3851     const int64_t output_size,
3852     const int64_t index_size,
3853     const int64_t data_size,
3854     const uint8_t* input,
3855     const int* indices,
3856     const int* offsets,
3857     const float* weights,
3858     const float* scale_bias,
3859     bool normalize_by_lengths,
3860     float* out) {
3861   return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<false>(
3862       block_size,
3863       output_size,
3864       index_size,
3865       data_size,
3866       input,
3867       indices,
3868       offsets,
3869       weights,
3870       scale_bias,
3871       normalize_by_lengths,
3872       out);
3873 }
3874 bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
3875     const int64_t block_size,
3876     const int64_t output_size,
3877     const int64_t index_size,
3878     const int64_t data_size,
3879     const uint8_t* input,
3880     const int* indices,
3881     const int* offsets,
3882     const float* weights,
3883     const float* scale_bias,
3884     bool normalize_by_lengths,
3885     float* out) {
3886   return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<true>(
3887       block_size,
3888       output_size,
3889       index_size,
3890       data_size,
3891       input,
3892       indices,
3893       offsets,
3894       weights,
3895       scale_bias,
3896       normalize_by_lengths,
3897       out);
3898 }
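// A minimal, hypothetical usage sketch for the int32_t/uint8_t entry points
// above (illustrative only; the names and sizes are assumptions, not part of
// the generated code). Two segments of quantized rows are pooled into float
// outputs; `offsets` has output_size + 1 entries in CSR style.
//
//   const int64_t block_size = 16, output_size = 2, data_size = 4;
//   std::vector<uint8_t> table(data_size * block_size);   // quantized rows
//   std::vector<float> scale_bias(2 * data_size);         // {scale, bias} per row
//   const int indices[] = {0, 2, 1, 3};                   // index_size = 4
//   const int offsets[] = {0, 2, 4};                      // output_size + 1 entries
//   std::vector<float> out(output_size * block_size);
//   bool ok = EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
//       block_size, output_size, /*index_size=*/4, data_size,
//       table.data(), indices, offsets, /*weights=*/nullptr,
//       scale_bias.data(), /*normalize_by_lengths=*/true, out.data());
//   // ok is false if any index falls outside [0, data_size) or if offsets
//   // and index_size are inconsistent.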
3899 
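// The int64_t variants below mirror the int32_t kernels above; only the index
// and offset element type changes.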
3900 template <bool IS_WEIGHT_POSITIONAL>
3901 static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma(
3902     const int64_t block_size,
3903     const int64_t output_size,
3904     const int64_t index_size,
3905     const int64_t data_size,
3906     const uint8_t* input,
3907     const int64_t* indices,
3908     const int64_t* offsets,
3909     const float* weights,
3910     const float* scale_bias,
3911     bool normalize_by_lengths,
3912     float* out) {
3913   const int64_t prefdist_T0 = 16;
3914   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3915   const int64_t fused_block_size = block_size + 0;
3916   int64_t dataInd = 0;
3917   if (block_size == 128) {
3918     // unrolling 16 times
3919     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3920       float* op = &out[rangeIndex * block_size];
3921       __m256 vop0 = _mm256_setzero_ps();
3922       __m256 vop8 = _mm256_setzero_ps();
3923       __m256 vop16 = _mm256_setzero_ps();
3924       __m256 vop24 = _mm256_setzero_ps();
3925       __m256 vop32 = _mm256_setzero_ps();
3926       __m256 vop40 = _mm256_setzero_ps();
3927       __m256 vop48 = _mm256_setzero_ps();
3928       __m256 vop56 = _mm256_setzero_ps();
3929       __m256 vop64 = _mm256_setzero_ps();
3930       __m256 vop72 = _mm256_setzero_ps();
3931       __m256 vop80 = _mm256_setzero_ps();
3932       __m256 vop88 = _mm256_setzero_ps();
3933       __m256 vop96 = _mm256_setzero_ps();
3934       __m256 vop104 = _mm256_setzero_ps();
3935       __m256 vop112 = _mm256_setzero_ps();
3936       __m256 vop120 = _mm256_setzero_ps();
3937       if (dataInd != offsets[rangeIndex] - offsets[0]) {
3938         return false;
3939       }
3940       int64_t end_offset = offsets[rangeIndex + 1];
3941       int64_t length = end_offset - offsets[rangeIndex];
3942       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3943            ++dataInd) {
3944         const int64_t idx = indices[dataInd];
3945         if (idx < 0 || idx >= data_size) {
3946           return false;
3947         }
3948         float wgt = 1.f;
3949         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3950         float bio;
3951         if (weights) {
3952           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3953         }
3954         bio = wgt * scale_bias[2 * idx + 1];
3955         wgt = wgt * scale_bias[2 * idx];
3956         __m256 vbio = _mm256_set1_ps(bio);
3957         __m256 vwgt = _mm256_set1_ps(wgt);
3958         const uint8_t* ip = &input[idx * fused_block_size];
3959         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3960             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3961             ? (dataInd + prefdist_T0)
3962             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3963             : dataInd;
3964         const int64_t idx_pref_T0 = indices[next_T0];
3965         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3966           return false;
3967         }
3968         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3969         vop0 = _mm256_fmadd_ps(
3970             vwgt,
3971             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3972                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
3973             _mm256_add_ps(vop0, vbio));
3974         _mm_prefetch(
3975             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3976         vop8 = _mm256_fmadd_ps(
3977             vwgt,
3978             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3979                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
3980             _mm256_add_ps(vop8, vbio));
3981         // skip unnecessary prefetch of (&ip_next_T0[8])
3982         vop16 = _mm256_fmadd_ps(
3983             vwgt,
3984             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3985                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
3986             _mm256_add_ps(vop16, vbio));
3987         // skip unnecessary prefetch of (&ip_next_T0[16])
3988         vop24 = _mm256_fmadd_ps(
3989             vwgt,
3990             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3991                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
3992             _mm256_add_ps(vop24, vbio));
3993         // skip unnecessary prefetch of (&ip_next_T0[24])
3994         vop32 = _mm256_fmadd_ps(
3995             vwgt,
3996             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
3997                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
3998             _mm256_add_ps(vop32, vbio));
3999         // skip unnecessary prefetch of (&ip_next_T0[32])
4000         vop40 = _mm256_fmadd_ps(
4001             vwgt,
4002             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4003                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
4004             _mm256_add_ps(vop40, vbio));
4005         // skip unnecessary prefetch of (&ip_next_T0[40])
4006         vop48 = _mm256_fmadd_ps(
4007             vwgt,
4008             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4009                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
4010             _mm256_add_ps(vop48, vbio));
4011         // skip unnecessary prefetch of (&ip_next_T0[48])
4012         vop56 = _mm256_fmadd_ps(
4013             vwgt,
4014             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4015                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
4016             _mm256_add_ps(vop56, vbio));
4017         // skip unnecessary prefetch of (&ip_next_T0[56])
4018         vop64 = _mm256_fmadd_ps(
4019             vwgt,
4020             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4021                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
4022             _mm256_add_ps(vop64, vbio));
4023         _mm_prefetch(
4024             reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
4025         vop72 = _mm256_fmadd_ps(
4026             vwgt,
4027             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4028                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
4029             _mm256_add_ps(vop72, vbio));
4030         // skip unnecessary prefetch of (&ip_next_T0[72])
4031         vop80 = _mm256_fmadd_ps(
4032             vwgt,
4033             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4034                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
4035             _mm256_add_ps(vop80, vbio));
4036         // skip unnecessary prefetch of (&ip_next_T0[80])
4037         vop88 = _mm256_fmadd_ps(
4038             vwgt,
4039             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4040                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
4041             _mm256_add_ps(vop88, vbio));
4042         // skip unnecessary prefetch of (&ip_next_T0[88])
4043         vop96 = _mm256_fmadd_ps(
4044             vwgt,
4045             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4046                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
4047             _mm256_add_ps(vop96, vbio));
4048         // skip unnecessary prefetch of (&ip_next_T0[96])
4049         vop104 = _mm256_fmadd_ps(
4050             vwgt,
4051             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4052                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
4053             _mm256_add_ps(vop104, vbio));
4054         // skip unnecessary prefetch of (&ip_next_T0[104])
4055         vop112 = _mm256_fmadd_ps(
4056             vwgt,
4057             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4058                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
4059             _mm256_add_ps(vop112, vbio));
4060         // skip unnecessary prefetch of (&ip_next_T0[112])
4061         vop120 = _mm256_fmadd_ps(
4062             vwgt,
4063             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4064                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
4065             _mm256_add_ps(vop120, vbio));
4066         // skip unnecessary prefetch of (&ip_next_T0[120])
4067       }
4068       if (!normalize_by_lengths || length == 0) {
4069         _mm256_storeu_ps(&op[0], vop0);
4070         _mm256_storeu_ps(&op[8], vop8);
4071         _mm256_storeu_ps(&op[16], vop16);
4072         _mm256_storeu_ps(&op[24], vop24);
4073         _mm256_storeu_ps(&op[32], vop32);
4074         _mm256_storeu_ps(&op[40], vop40);
4075         _mm256_storeu_ps(&op[48], vop48);
4076         _mm256_storeu_ps(&op[56], vop56);
4077         _mm256_storeu_ps(&op[64], vop64);
4078         _mm256_storeu_ps(&op[72], vop72);
4079         _mm256_storeu_ps(&op[80], vop80);
4080         _mm256_storeu_ps(&op[88], vop88);
4081         _mm256_storeu_ps(&op[96], vop96);
4082         _mm256_storeu_ps(&op[104], vop104);
4083         _mm256_storeu_ps(&op[112], vop112);
4084         _mm256_storeu_ps(&op[120], vop120);
4085       } else {
4086         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
4087         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
4088         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
4089         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
4090         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
4091         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
4092         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
4093         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
4094         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
4095         _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
4096         _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
4097         _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
4098         _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
4099         _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
4100         _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
4101         _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
4102         _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
4103       }
4104     }
4105   } else if (block_size == 64) {
4106     // unrolling 8 times
4107     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
4108       float* op = &out[rangeIndex * block_size];
4109       __m256 vop0 = _mm256_setzero_ps();
4110       __m256 vop8 = _mm256_setzero_ps();
4111       __m256 vop16 = _mm256_setzero_ps();
4112       __m256 vop24 = _mm256_setzero_ps();
4113       __m256 vop32 = _mm256_setzero_ps();
4114       __m256 vop40 = _mm256_setzero_ps();
4115       __m256 vop48 = _mm256_setzero_ps();
4116       __m256 vop56 = _mm256_setzero_ps();
4117       if (dataInd != offsets[rangeIndex] - offsets[0]) {
4118         return false;
4119       }
4120       int64_t end_offset = offsets[rangeIndex + 1];
4121       int64_t length = end_offset - offsets[rangeIndex];
4122       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
4123            ++dataInd) {
4124         const int64_t idx = indices[dataInd];
4125         if (idx < 0 || idx >= data_size) {
4126           return false;
4127         }
4128         float wgt = 1.f;
4129         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
4130         float bio;
4131         if (weights) {
4132           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
4133         }
4134         bio = wgt * scale_bias[2 * idx + 1];
4135         wgt = wgt * scale_bias[2 * idx];
4136         __m256 vbio = _mm256_set1_ps(bio);
4137         __m256 vwgt = _mm256_set1_ps(wgt);
4138         const uint8_t* ip = &input[idx * fused_block_size];
4139         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
4140             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4141             ? (dataInd + prefdist_T0)
4142             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4143             : dataInd;
4144         const int64_t idx_pref_T0 = indices[next_T0];
4145         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
4146           return false;
4147         }
4148         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
4149         vop0 = _mm256_fmadd_ps(
4150             vwgt,
4151             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4152                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
4153             _mm256_add_ps(vop0, vbio));
4154         _mm_prefetch(
4155             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
4156         vop8 = _mm256_fmadd_ps(
4157             vwgt,
4158             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4159                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
4160             _mm256_add_ps(vop8, vbio));
4161         // skip unnecessary prefetch of (&ip_next_T0[8])
4162         vop16 = _mm256_fmadd_ps(
4163             vwgt,
4164             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4165                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
4166             _mm256_add_ps(vop16, vbio));
4167         // skip unnecessary prefetch of (&ip_next_T0[16])
4168         vop24 = _mm256_fmadd_ps(
4169             vwgt,
4170             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4171                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
4172             _mm256_add_ps(vop24, vbio));
4173         // skip unnecessary prefetch of (&ip_next_T0[24])
4174         vop32 = _mm256_fmadd_ps(
4175             vwgt,
4176             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4177                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
4178             _mm256_add_ps(vop32, vbio));
4179         // skip unnecessary prefetch of (&ip_next_T0[32])
4180         vop40 = _mm256_fmadd_ps(
4181             vwgt,
4182             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4183                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
4184             _mm256_add_ps(vop40, vbio));
4185         // skip unnecessary prefetch of (&ip_next_T0[40])
4186         vop48 = _mm256_fmadd_ps(
4187             vwgt,
4188             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4189                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
4190             _mm256_add_ps(vop48, vbio));
4191         // skip unnecessary prefetch of (&ip_next_T0[48])
4192         vop56 = _mm256_fmadd_ps(
4193             vwgt,
4194             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4195                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
4196             _mm256_add_ps(vop56, vbio));
4197         // skip unnecessary prefetch of (&ip_next_T0[56])
4198       }
4199       if (!normalize_by_lengths || length == 0) {
4200         _mm256_storeu_ps(&op[0], vop0);
4201         _mm256_storeu_ps(&op[8], vop8);
4202         _mm256_storeu_ps(&op[16], vop16);
4203         _mm256_storeu_ps(&op[24], vop24);
4204         _mm256_storeu_ps(&op[32], vop32);
4205         _mm256_storeu_ps(&op[40], vop40);
4206         _mm256_storeu_ps(&op[48], vop48);
4207         _mm256_storeu_ps(&op[56], vop56);
4208       } else {
4209         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
4210         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
4211         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
4212         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
4213         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
4214         _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
4215         _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
4216         _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
4217         _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
4218       }
4219     }
4220   } else if (block_size == 32) {
4221     // unrolling 4 times
4222     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
4223       float* op = &out[rangeIndex * block_size];
4224       __m256 vop0 = _mm256_setzero_ps();
4225       __m256 vop8 = _mm256_setzero_ps();
4226       __m256 vop16 = _mm256_setzero_ps();
4227       __m256 vop24 = _mm256_setzero_ps();
4228       if (dataInd != offsets[rangeIndex] - offsets[0]) {
4229         return false;
4230       }
4231       int64_t end_offset = offsets[rangeIndex + 1];
4232       int64_t length = end_offset - offsets[rangeIndex];
4233       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
4234            ++dataInd) {
4235         const int64_t idx = indices[dataInd];
4236         if (idx < 0 || idx >= data_size) {
4237           return false;
4238         }
4239         float wgt = 1.f;
4240         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
4241         float bio;
4242         if (weights) {
4243           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
4244         }
4245         bio = wgt * scale_bias[2 * idx + 1];
4246         wgt = wgt * scale_bias[2 * idx];
4247         __m256 vbio = _mm256_set1_ps(bio);
4248         __m256 vwgt = _mm256_set1_ps(wgt);
4249         const uint8_t* ip = &input[idx * fused_block_size];
4250         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
4251             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4252             ? (dataInd + prefdist_T0)
4253             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4254             : dataInd;
4255         const int64_t idx_pref_T0 = indices[next_T0];
4256         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
4257           return false;
4258         }
4259         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
4260         vop0 = _mm256_fmadd_ps(
4261             vwgt,
4262             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4263                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
4264             _mm256_add_ps(vop0, vbio));
4265         _mm_prefetch(
4266             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
4267         vop8 = _mm256_fmadd_ps(
4268             vwgt,
4269             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4270                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
4271             _mm256_add_ps(vop8, vbio));
4272         // skip unnecessary prefetch of (&ip_next_T0[8])
4273         vop16 = _mm256_fmadd_ps(
4274             vwgt,
4275             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4276                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
4277             _mm256_add_ps(vop16, vbio));
4278         // skip unnecessary prefetch of (&ip_next_T0[16])
4279         vop24 = _mm256_fmadd_ps(
4280             vwgt,
4281             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4282                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
4283             _mm256_add_ps(vop24, vbio));
4284         // skip unnecessary prefetch of (&ip_next_T0[24])
4285       }
4286       if (!normalize_by_lengths || length == 0) {
4287         _mm256_storeu_ps(&op[0], vop0);
4288         _mm256_storeu_ps(&op[8], vop8);
4289         _mm256_storeu_ps(&op[16], vop16);
4290         _mm256_storeu_ps(&op[24], vop24);
4291       } else {
4292         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
4293         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
4294         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
4295         _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
4296         _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
4297       }
4298     }
4299   } else if (block_size == 16) {
4300     // unrolling 2 times
4301     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
4302       float* op = &out[rangeIndex * block_size];
4303       __m256 vop0 = _mm256_setzero_ps();
4304       __m256 vop8 = _mm256_setzero_ps();
4305       if (dataInd != offsets[rangeIndex] - offsets[0]) {
4306         return false;
4307       }
4308       int64_t end_offset = offsets[rangeIndex + 1];
4309       int64_t length = end_offset - offsets[rangeIndex];
4310       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
4311            ++dataInd) {
4312         const int64_t idx = indices[dataInd];
4313         if (idx < 0 || idx >= data_size) {
4314           return false;
4315         }
4316         float wgt = 1.f;
4317         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
4318         float bio;
4319         if (weights) {
4320           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
4321         }
4322         bio = wgt * scale_bias[2 * idx + 1];
4323         wgt = wgt * scale_bias[2 * idx];
4324         __m256 vbio = _mm256_set1_ps(bio);
4325         __m256 vwgt = _mm256_set1_ps(wgt);
4326         const uint8_t* ip = &input[idx * fused_block_size];
4327         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
4328             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4329             ? (dataInd + prefdist_T0)
4330             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4331             : dataInd;
4332         const int64_t idx_pref_T0 = indices[next_T0];
4333         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
4334           return false;
4335         }
4336         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
4337         vop0 = _mm256_fmadd_ps(
4338             vwgt,
4339             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4340                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
4341             _mm256_add_ps(vop0, vbio));
4342         _mm_prefetch(
4343             reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
4344         vop8 = _mm256_fmadd_ps(
4345             vwgt,
4346             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
4347                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
4348             _mm256_add_ps(vop8, vbio));
4349         // skip unnecessary prefetch of (&ip_next_T0[8])
4350       }
4351       if (!normalize_by_lengths || length == 0) {
4352         _mm256_storeu_ps(&op[0], vop0);
4353         _mm256_storeu_ps(&op[8], vop8);
4354       } else {
4355         __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
4356         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
4357         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
4358       }
4359     }
4360   } else {
4361     // generic code
4362     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
4363     for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
4364       float* op = &out[rangeIndex * block_size];
4365       int64_t j = 0;
4366       for (; j + 8 <= block_size; j += 8) {
4367         _mm256_storeu_ps(op + j, _mm256_setzero_ps());
4368       }
4369       for (; j < block_size; j++) {
4370         op[j] = 0.0f;
4371       }
4372       if (dataInd != offsets[rangeIndex] - offsets[0]) {
4373         return false;
4374       }
4375       int64_t end_offset = offsets[rangeIndex + 1];
4376       int64_t length = end_offset - offsets[rangeIndex];
4377       for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
4378            ++dataInd) {
4379         const int64_t idx = indices[dataInd];
4380         if (idx < 0 || idx >= data_size) {
4381           return false;
4382         }
4383         float wgt = 1.f;
4384         // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
4385         float bio;
4386         if (weights) {
4387           wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
4388         }
4389         bio = wgt * scale_bias[2 * idx + 1];
4390         wgt = wgt * scale_bias[2 * idx];
4391         __m256 vbio = _mm256_set1_ps(bio);
4392         __m256 vwgt = _mm256_set1_ps(wgt);
4393         const uint8_t* ip = &input[idx * fused_block_size];
4394         const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
4395             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4396             ? (dataInd + prefdist_T0)
4397             // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
4398             : dataInd;
4399         const int64_t idx_pref_T0 = indices[next_T0];
4400         if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
4401           return false;
4402         }
4403         const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
4404         j = 0;
4405         for (; j + 8 <= block_size; j += 8) {
4406           _mm256_storeu_ps(
4407               &op[j],
4408               _mm256_fmadd_ps(
4409                   vwgt,
4410                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
4411                       reinterpret_cast<const __m128i*>(&ip[j])))),
4412                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
4413           _mm_prefetch(
4414               reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
4415         }
4416         for (; j < block_size; j++) {
4417           op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);
4418         }
4419       }
4420       if (normalize_by_lengths && length) {
4421         float len_inv = 1.0f / length;
4422         __m256 vlen_inv = _mm256_set1_ps(len_inv);
4423         j = 0;
4424         for (; j + 8 <= block_size; j += 8) {
4425           _mm256_storeu_ps(
4426               &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
4427         }
4428         for (; j < block_size; j++) {
4429           op[j] = len_inv * op[j];
4430         }
4431       }
4432     }
4433   }
4434   return dataInd == index_size;
4435 }
4436 bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__avx2_fma(
4437     const int64_t block_size,
4438     const int64_t output_size,
4439     const int64_t index_size,
4440     const int64_t data_size,
4441     const uint8_t* input,
4442     const int64_t* indices,
4443     const int64_t* offsets,
4444     const float* weights,
4445     const float* scale_bias,
4446     bool normalize_by_lengths,
4447     float* out) {
4448   return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<false>(
4449       block_size,
4450       output_size,
4451       index_size,
4452       data_size,
4453       input,
4454       indices,
4455       offsets,
4456       weights,
4457       scale_bias,
4458       normalize_by_lengths,
4459       out);
4460 }
4461 bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__avx2_fma(
4462     const int64_t block_size,
4463     const int64_t output_size,
4464     const int64_t index_size,
4465     const int64_t data_size,
4466     const uint8_t* input,
4467     const int64_t* indices,
4468     const int64_t* offsets,
4469     const float* weights,
4470     const float* scale_bias,
4471     bool normalize_by_lengths,
4472     float* out) {
4473   return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<true>(
4474       block_size,
4475       output_size,
4476       index_size,
4477       data_size,
4478       input,
4479       indices,
4480       offsets,
4481       weights,
4482       scale_bias,
4483       normalize_by_lengths,
4484       out);
4485 }
4486 
4487 } // namespace caffe2
4488