xref: /aosp_15_r20/external/XNNPACK/src/qc8-dwconv/gen/up32x3-minmax-fp32-avx512skx-mul32.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Auto-generated file. Do not edit!
//   Template: src/qs8-dwconv/unipass-avx512skx-mul32.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
9 
10 #include <assert.h>
11 
12 #include <immintrin.h>
13 
14 #include <xnnpack/dwconv.h>
15 #include <xnnpack/intrinsics-polyfill.h>
16 
17 
xnn_qc8_dwconv_minmax_fp32_ukernel_up32x3__avx512skx_mul32(size_t channels,size_t output_width,const int8_t ** input,const void * weights,int8_t * output,size_t input_stride,size_t output_increment,size_t input_offset,const int8_t * zero,const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])18 void xnn_qc8_dwconv_minmax_fp32_ukernel_up32x3__avx512skx_mul32(
19     size_t channels,
20     size_t output_width,
21     const int8_t** input,
22     const void* weights,
23     int8_t* output,
24     size_t input_stride,
25     size_t output_increment,
26     size_t input_offset,
27     const int8_t* zero,
28     const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
29 {
30   assert(channels != 0);
31   assert(output_width != 0);
32 
33   const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
34   const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
35   const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
36   const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
37 
38   do {
39     const int8_t* i0 = input[0];
40     assert(i0 != NULL);
41     if XNN_UNPREDICTABLE(i0 != zero) {
42       i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
43     }
44     const int8_t* i1 = input[1];
45     assert(i1 != NULL);
46     if XNN_UNPREDICTABLE(i1 != zero) {
47       i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
48     }
49     const int8_t* i2 = input[2];
50     assert(i2 != NULL);
51     if XNN_UNPREDICTABLE(i2 != zero) {
52       i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
53     }
54     input = (const int8_t**) ((uintptr_t) input + input_stride);
55 
56     size_t c = channels;
57     const void* w = weights;
58     for (; c >= 32; c -= 32) {
59       __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
60       __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
61 
62 
63       const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
64       const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
65       const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
66       const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
67       i0 += 32;
68 
69       vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
70       vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
71 
72       const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
73       const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
74       const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
75       const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
76       i1 += 32;
77 
78       vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
79       vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
80 
81       const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
82       const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
83       const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
84       const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
85       i2 += 32;
86 
87       vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
88       vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
89 
90       w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t));
91 
92       __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
93       __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
94 
95       const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps(w);
96       const __m512 vscaleGHIJKLMNOPQRSTUV = _mm512_loadu_ps((const void*) ((uintptr_t) w + 16 * sizeof(float)));
97       w = (const void*) ((uintptr_t) w + 32 * sizeof(float));
98       vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
99       vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscaleGHIJKLMNOPQRSTUV);
100 
101       vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
102       vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
103 
104       vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
105       vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
106 
107       __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
108       __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
109 
110       const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
111       const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
112       const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
113       __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
114       const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
115       const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
116       __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packs_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
117 
118       vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
119       voutGHIJKLMNOPQRSTUV = _mm_max_epi8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
120 
121       _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
122       _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
123       output += 32;
124     }
125     if XNN_UNLIKELY(c != 0) {
126       // Prepare mask for valid 8-bit elements (depends on nc).
127       const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
128       const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
129       do {
130         __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
131 
132 
133         const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
134         const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
135         i0 += 16;
136 
137         vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
138 
139         const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
140         const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
141         i1 += 16;
142 
143         vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
144 
145         const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
146         const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
147         i2 += 16;
148 
149         vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
150 
151         k += 16;
152 
153         __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
154         const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps((const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t)));
155         vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
156         vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
157         vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
158 
159         w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
160 
161         __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
162 
163         const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
164         const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
165         __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
166         vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
167 
168         if XNN_LIKELY(c >= 16) {
169           _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
170           output += 16;
171           c -= 16;
172         } else {
173           _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
174           output = (int8_t*) ((uintptr_t) output + c);
175           c = 0;
176         }
177       } while (c != 0);
178     }
179 
180     output = (int8_t*) ((uintptr_t) output + output_increment);
181   } while (--output_width != 0);
182 }
183