/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/common.h>
#include <qnnpack/q8vadd.h>
#include <qnnpack/scalar-utils.h>

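/*
 * q8vadd: element-wise addition of two quantized uint8 vectors a and b
 * into y. Each output element is requantized in fixed point, roughly
 *   y[i] = clamp(round_asr(zero_point_product
 *                          + a[i] * a_multiplier
 *                          + b[i] * b_multiplier, shift)
 *                + y_zero_point, y_min, y_max)
 * with all parameters taken from quantization_params. The SSE2 path
 * handles 8 elements per iteration when n >= 8; smaller inputs use the
 * scalar loop at the bottom of this file, which spells out the same
 * arithmetic.
 */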
void pytorch_q8vadd_ukernel__sse2(
    size_t n,
    const uint8_t* a,
    const uint8_t* b,
    uint8_t* y,
    const union pytorch_qnnp_add_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
  if PYTORCH_QNNP_LIKELY(n >= 8) {
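    /*
     * Load the requantization parameters, each broadcast across all SIMD
     * lanes. zero_point_product is a bias that accounts for the input zero
     * points (set up by the caller, not in this file), and each 32-bit
     * multiplier is supplied as separate low/high 16-bit halves for the
     * 16-bit multiplies below.
     */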
    const __m128i vzero_point_product = _mm_load_si128(
        (const __m128i*)&quantization_params->sse2.zero_point_product);
    const __m128i va_multiplier_lo = _mm_load_si128(
        (const __m128i*)&quantization_params->sse2.a_multiplier_lo);
    const __m128i va_multiplier_hi = _mm_load_si128(
        (const __m128i*)&quantization_params->sse2.a_multiplier_hi);
    const __m128i vb_multiplier_lo = _mm_load_si128(
        (const __m128i*)&quantization_params->sse2.b_multiplier_lo);
    const __m128i vb_multiplier_hi = _mm_load_si128(
        (const __m128i*)&quantization_params->sse2.b_multiplier_hi);
    const __m128i vremainder_mask = _mm_load_si128(
        (const __m128i*)quantization_params->sse2.remainder_mask);
    const __m128i vremainder_threshold = _mm_load_si128(
        (const __m128i*)quantization_params->sse2.remainder_threshold);
    const __m128i vshift =
        _mm_cvtsi32_si128((int)quantization_params->sse2.shift);

    const __m128i vzero = _mm_setzero_si128();
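    /* Main loop: process 8 elements of a and b per iteration. */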
    do {
      const __m128i va = _mm_loadl_epi64((const __m128i*)a);
      a += 8;
      const __m128i vb = _mm_loadl_epi64((const __m128i*)b);
      b += 8;

      const __m128i vxa = _mm_unpacklo_epi8(va, vzero);
      const __m128i vxb = _mm_unpacklo_epi8(vb, vzero);

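      /*
       * Each 32-bit product vxa * a_multiplier is assembled from 16-bit
       * multiplies: mullo yields the low 16 bits, and the high 16 bits are
       * the sum (modulo 2^16) of the high half of vxa * multiplier_lo and
       * the low half of vxa * multiplier_hi. The lo/hi halves are then
       * interleaved into 32-bit lanes by the unpack instructions below.
       */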
      /* Multiply by factors */
      const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo);
      const __m128i va_product_hi = _mm_add_epi16(
          _mm_mulhi_epu16(vxa, va_multiplier_lo),
          _mm_mullo_epi16(vxa, va_multiplier_hi));

      const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo);
      const __m128i vb_product_hi = _mm_add_epi16(
          _mm_mulhi_epu16(vxb, vb_multiplier_lo),
          _mm_mullo_epi16(vxb, vb_multiplier_hi));

      /* Accumulate products */
      __m128i vacc_lo = _mm_add_epi32(
          vzero_point_product,
          _mm_unpacklo_epi16(va_product_lo, va_product_hi));
      __m128i vacc_hi = _mm_add_epi32(
          vzero_point_product,
          _mm_unpackhi_epi16(va_product_lo, va_product_hi));

      vacc_lo = _mm_add_epi32(
          vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi));
      vacc_hi = _mm_add_epi32(
          vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi));

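      /*
       * Divide by 2^shift with rounding: the shifted-out remainder
       * (adjusted by -1 for negative accumulators) is compared against the
       * rounding threshold, and 1 is added wherever it exceeds it, so the
       * result matches a correctly rounded arithmetic shift.
       */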
      /* Shift right and round */
      const __m128i vrem_lo = _mm_add_epi32(
          _mm_and_si128(vacc_lo, vremainder_mask),
          _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo));
      const __m128i vrem_hi = _mm_add_epi32(
          _mm_and_si128(vacc_hi, vremainder_mask),
          _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi));

      vacc_lo = _mm_sub_epi32(
          _mm_sra_epi32(vacc_lo, vshift),
          _mm_cmpgt_epi32(vrem_lo, vremainder_threshold));
      vacc_hi = _mm_sub_epi32(
          _mm_sra_epi32(vacc_hi, vshift),
          _mm_cmpgt_epi32(vrem_hi, vremainder_threshold));

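      /*
       * _mm_packs_epi32 saturates the 32-bit results to int16, the zero
       * point is added with a saturating 16-bit add, _mm_packus_epi16
       * saturates to uint8, and the final max/min clamp enforces the
       * [y_min, y_max] output range.
       */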
      /* Pack, saturate, and add output zero point */
      const __m128i vy_zero_point = _mm_load_si128(
          (const __m128i*)quantization_params->sse2.y_zero_point);
      const __m128i vacc =
          _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point);
      __m128i vy = _mm_packus_epi16(vacc, vacc);
      vy = _mm_max_epu8(
          vy,
          _mm_load_si128((const __m128i*)quantization_params->sse2.y_min));
      vy = _mm_min_epu8(
          vy,
          _mm_load_si128((const __m128i*)quantization_params->sse2.y_max));

      _mm_storel_epi64((__m128i*)y, vy);
      y += 8;

      n -= 8;
    } while (n >= 8);
    if (n != 0) {
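      /*
       * Handle the 1-7 remaining elements: reload the last 8 bytes of each
       * input (overlapping bytes already processed above) and shift the
       * valid elements down into the low lanes, so the loads never read
       * past the end of the input buffers.
       */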
      const size_t n_decrement = 8 - n;
      const __m128i vload_shift = _mm_cvtsi32_si128(8 * (int32_t)n_decrement);

      const __m128i va = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a - n_decrement)), vload_shift);
      const __m128i vb = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(b - n_decrement)), vload_shift);

      const __m128i vxa = _mm_unpacklo_epi8(va, vzero);
      const __m128i vxb = _mm_unpacklo_epi8(vb, vzero);

      /* Multiply by factors */
      const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo);
      const __m128i va_product_hi = _mm_add_epi16(
          _mm_mulhi_epu16(vxa, va_multiplier_lo),
          _mm_mullo_epi16(vxa, va_multiplier_hi));

      const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo);
      const __m128i vb_product_hi = _mm_add_epi16(
          _mm_mulhi_epu16(vxb, vb_multiplier_lo),
          _mm_mullo_epi16(vxb, vb_multiplier_hi));

      /* Accumulate products */
      __m128i vacc_lo = _mm_add_epi32(
          vzero_point_product,
          _mm_unpacklo_epi16(va_product_lo, va_product_hi));
      __m128i vacc_hi = _mm_add_epi32(
          vzero_point_product,
          _mm_unpackhi_epi16(va_product_lo, va_product_hi));

      vacc_lo = _mm_add_epi32(
          vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi));
      vacc_hi = _mm_add_epi32(
          vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi));

      /* Shift right and round */
      const __m128i vrem_lo = _mm_add_epi32(
          _mm_and_si128(vacc_lo, vremainder_mask),
          _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo));
      const __m128i vrem_hi = _mm_add_epi32(
          _mm_and_si128(vacc_hi, vremainder_mask),
          _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi));

      vacc_lo = _mm_sub_epi32(
          _mm_sra_epi32(vacc_lo, vshift),
          _mm_cmpgt_epi32(vrem_lo, vremainder_threshold));
      vacc_hi = _mm_sub_epi32(
          _mm_sra_epi32(vacc_hi, vshift),
          _mm_cmpgt_epi32(vrem_hi, vremainder_threshold));

      /* Pack, saturate, and add output zero point */
      const __m128i vy_zero_point = _mm_load_si128(
          (const __m128i*)quantization_params->sse2.y_zero_point);
      const __m128i vacc =
          _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point);
      __m128i vy = _mm_packus_epi16(vacc, vacc);
      vy = _mm_max_epu8(
          vy,
          _mm_load_si128((const __m128i*)quantization_params->sse2.y_min));
      vy = _mm_min_epu8(
          vy,
          _mm_load_si128((const __m128i*)quantization_params->sse2.y_max));

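      /* Store the remaining 1-7 output bytes in 4/2/1-byte pieces, moving
         the next bytes of vy into the low lanes after each partial store. */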
172 if (n & 4) {
173 *((uint32_t*)y) = (uint32_t)_mm_cvtsi128_si32(vy);
174 vy = _mm_shuffle_epi32(vy, _MM_SHUFFLE(3, 2, 1, 1));
175 y += 4;
176 }
177 if (n & 2) {
178 *((uint16_t*)y) = (uint16_t)_mm_extract_epi16(vy, 0);
179 vy = _mm_srli_epi32(vy, 16);
180 y += 2;
181 }
182 if (n & 1) {
183 *((uint8_t*)y) = (uint8_t)_mm_cvtsi128_si32(vy);
184 }
185 }
186 }
187 else {
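    /*
     * Scalar fallback for fewer than 8 elements: the same fixed-point
     * requantization as above, using scalar copies of the same parameters.
     */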
    const int32_t vzero_point_product =
        quantization_params->sse2.zero_point_product[0];
    const uint32_t va_multiplier = quantization_params->sse2.a_multiplier;
    const uint32_t vb_multiplier = quantization_params->sse2.b_multiplier;
    const int32_t vremainder_mask =
        quantization_params->sse2.remainder_mask[0];
    const int32_t vremainder_threshold =
        quantization_params->sse2.remainder_threshold[0];
    const uint32_t vshift = quantization_params->sse2.shift;
    const int32_t vy_zero_point =
        (int32_t)quantization_params->sse2.y_zero_point[0];
    const int32_t vy_max =
        (int32_t)(uint32_t)quantization_params->sse2.y_max[0];
    const int32_t vy_min =
        (int32_t)(uint32_t)quantization_params->sse2.y_min[0];

    while (n-- != 0) {
      const uint32_t vxa = (uint32_t)*a++;
      const uint32_t vxb = (uint32_t)*b++;

      /* Multiply by factors and accumulate products */
      int32_t vacc = vzero_point_product + (int32_t)(vxa * va_multiplier) +
          (int32_t)(vxb * vb_multiplier);

      /* Shift right and round */
      const int32_t vrem = (vacc & vremainder_mask) - (int32_t)(vacc < 0);

      vacc = asr_s32(vacc, vshift) + (int32_t)(vrem > vremainder_threshold);

      /* Clamp and add output zero point */
      int32_t vy = vacc + vy_zero_point;
      vy = vy >= vy_min ? vy : vy_min;
      vy = vy <= vy_max ? vy : vy_max;

      *y++ = (uint8_t)vy;
    }
  }
}