/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/common.h>
#include <qnnpack/q8vadd.h>
#include <qnnpack/scalar-utils.h>

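/*
 * SSE2 micro-kernel for elementwise addition of two quantized uint8 vectors:
 * computes zero_point_product + a[i] * a_multiplier + b[i] * b_multiplier,
 * applies a rounding right shift, adds the output zero point, and clamps to
 * [y_min, y_max]. The vector path handles 8 elements per iteration; inputs
 * shorter than 8 elements take the scalar fixed-point path at the bottom.
 */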
void pytorch_q8vadd_ukernel__sse2(
    size_t n,
    const uint8_t* a,
    const uint8_t* b,
    uint8_t* y,
    const union pytorch_qnnp_add_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
  if
    PYTORCH_QNNP_LIKELY(n >= 8) {
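      /*
       * Each SSE2 quantization parameter is stored pre-broadcast across the
       * lanes of a 16-byte-aligned field (the scalar path below reads lane 0
       * of the same fields), so one aligned load replicates the constant
       * into every lane.
       */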
      const __m128i vzero_point_product = _mm_load_si128(
          (const __m128i*)&quantization_params->sse2.zero_point_product);
      const __m128i va_multiplier_lo = _mm_load_si128(
          (const __m128i*)&quantization_params->sse2.a_multiplier_lo);
      const __m128i va_multiplier_hi = _mm_load_si128(
          (const __m128i*)&quantization_params->sse2.a_multiplier_hi);
      const __m128i vb_multiplier_lo = _mm_load_si128(
          (const __m128i*)&quantization_params->sse2.b_multiplier_lo);
      const __m128i vb_multiplier_hi = _mm_load_si128(
          (const __m128i*)&quantization_params->sse2.b_multiplier_hi);
      const __m128i vremainder_mask = _mm_load_si128(
          (const __m128i*)quantization_params->sse2.remainder_mask);
      const __m128i vremainder_threshold = _mm_load_si128(
          (const __m128i*)quantization_params->sse2.remainder_threshold);
      const __m128i vshift =
          _mm_cvtsi32_si128((int)quantization_params->sse2.shift);

      const __m128i vzero = _mm_setzero_si128();
      do {
        const __m128i va = _mm_loadl_epi64((const __m128i*)a);
        a += 8;
        const __m128i vb = _mm_loadl_epi64((const __m128i*)b);
        b += 8;

        const __m128i vxa = _mm_unpacklo_epi8(va, vzero);
        const __m128i vxb = _mm_unpacklo_epi8(vb, vzero);

        /* Multiply by factors */
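        /*
         * SSE2 has no 16x32-bit widening multiply, so each 32-bit multiplier
         * is pre-split into 16-bit halves. The low 16 bits of each 32-bit
         * product come from _mm_mullo_epi16; the high 16 bits combine the
         * carry out of the low half (_mm_mulhi_epu16) with the product
         * against the multiplier's upper half.
         */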
        const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo);
        const __m128i va_product_hi = _mm_add_epi16(
            _mm_mulhi_epu16(vxa, va_multiplier_lo),
            _mm_mullo_epi16(vxa, va_multiplier_hi));

        const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo);
        const __m128i vb_product_hi = _mm_add_epi16(
            _mm_mulhi_epu16(vxb, vb_multiplier_lo),
            _mm_mullo_epi16(vxb, vb_multiplier_hi));

        /* Accumulate products */
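        /*
         * _mm_unpacklo/hi_epi16 interleave the low and high 16-bit product
         * halves back into full 32-bit lanes before the 32-bit additions.
         */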
        __m128i vacc_lo = _mm_add_epi32(
            vzero_point_product,
            _mm_unpacklo_epi16(va_product_lo, va_product_hi));
        __m128i vacc_hi = _mm_add_epi32(
            vzero_point_product,
            _mm_unpackhi_epi16(va_product_lo, va_product_hi));

        vacc_lo = _mm_add_epi32(
            vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi));
        vacc_hi = _mm_add_epi32(
            vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi));

        /* Shift right and round */
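        /*
         * The arithmetic shift below rounds toward negative infinity. To get
         * round-to-nearest, compare the masked-off remainder bits (biased by
         * -1 for negative accumulators) against the threshold; subtracting
         * the all-ones compare mask adds 1 when the remainder exceeds it.
         */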
        const __m128i vrem_lo = _mm_add_epi32(
            _mm_and_si128(vacc_lo, vremainder_mask),
            _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo));
        const __m128i vrem_hi = _mm_add_epi32(
            _mm_and_si128(vacc_hi, vremainder_mask),
            _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi));

        vacc_lo = _mm_sub_epi32(
            _mm_sra_epi32(vacc_lo, vshift),
            _mm_cmpgt_epi32(vrem_lo, vremainder_threshold));
        vacc_hi = _mm_sub_epi32(
            _mm_sra_epi32(vacc_hi, vshift),
            _mm_cmpgt_epi32(vrem_hi, vremainder_threshold));

        /* Pack, saturate, and add output zero point */
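        /*
         * _mm_packs_epi32 narrows to 16 bits with signed saturation,
         * _mm_adds_epi16 adds the zero point with saturation, and
         * _mm_packus_epi16 narrows to uint8; the max/min pair then clamps
         * to the requested [y_min, y_max] output range.
         */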
        const __m128i vy_zero_point = _mm_load_si128(
            (const __m128i*)quantization_params->sse2.y_zero_point);
        const __m128i vacc =
            _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point);
        __m128i vy = _mm_packus_epi16(vacc, vacc);
        vy = _mm_max_epu8(
            vy,
            _mm_load_si128((const __m128i*)quantization_params->sse2.y_min));
        vy = _mm_min_epu8(
            vy,
            _mm_load_si128((const __m128i*)quantization_params->sse2.y_max));

        _mm_storel_epi64((__m128i*)y, vy);
        y += 8;

        n -= 8;
      } while (n >= 8);
      if (n != 0) {
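        /*
         * Remainder of 1-7 elements: re-read the last 8 input bytes (these
         * overlap bytes already processed, which is safe because n >= 8 on
         * entry) and shift the unprocessed elements down to the low lanes.
         */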
        const size_t n_decrement = 8 - n;
        const __m128i vload_shift = _mm_cvtsi32_si128(8 * (int32_t)n_decrement);

        const __m128i va = _mm_srl_epi64(
            _mm_loadl_epi64((const __m128i*)(a - n_decrement)), vload_shift);
        const __m128i vb = _mm_srl_epi64(
            _mm_loadl_epi64((const __m128i*)(b - n_decrement)), vload_shift);

        const __m128i vxa = _mm_unpacklo_epi8(va, vzero);
        const __m128i vxb = _mm_unpacklo_epi8(vb, vzero);

        /* Multiply by factors */
        const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo);
        const __m128i va_product_hi = _mm_add_epi16(
            _mm_mulhi_epu16(vxa, va_multiplier_lo),
            _mm_mullo_epi16(vxa, va_multiplier_hi));

        const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo);
        const __m128i vb_product_hi = _mm_add_epi16(
            _mm_mulhi_epu16(vxb, vb_multiplier_lo),
            _mm_mullo_epi16(vxb, vb_multiplier_hi));

        /* Accumulate products */
        __m128i vacc_lo = _mm_add_epi32(
            vzero_point_product,
            _mm_unpacklo_epi16(va_product_lo, va_product_hi));
        __m128i vacc_hi = _mm_add_epi32(
            vzero_point_product,
            _mm_unpackhi_epi16(va_product_lo, va_product_hi));

        vacc_lo = _mm_add_epi32(
            vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi));
        vacc_hi = _mm_add_epi32(
            vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi));

        /* Shift right and round */
        const __m128i vrem_lo = _mm_add_epi32(
            _mm_and_si128(vacc_lo, vremainder_mask),
            _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo));
        const __m128i vrem_hi = _mm_add_epi32(
            _mm_and_si128(vacc_hi, vremainder_mask),
            _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi));

        vacc_lo = _mm_sub_epi32(
            _mm_sra_epi32(vacc_lo, vshift),
            _mm_cmpgt_epi32(vrem_lo, vremainder_threshold));
        vacc_hi = _mm_sub_epi32(
            _mm_sra_epi32(vacc_hi, vshift),
            _mm_cmpgt_epi32(vrem_hi, vremainder_threshold));

        /* Pack, saturate, and add output zero point */
        const __m128i vy_zero_point = _mm_load_si128(
            (const __m128i*)quantization_params->sse2.y_zero_point);
        const __m128i vacc =
            _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point);
        __m128i vy = _mm_packus_epi16(vacc, vacc);
        vy = _mm_max_epu8(
            vy,
            _mm_load_si128((const __m128i*)quantization_params->sse2.y_min));
        vy = _mm_min_epu8(
            vy,
            _mm_load_si128((const __m128i*)quantization_params->sse2.y_max));

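        /* Store the 1-7 result bytes without writing past the end of y. */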
172         if (n & 4) {
173           *((uint32_t*)y) = (uint32_t)_mm_cvtsi128_si32(vy);
174           vy = _mm_shuffle_epi32(vy, _MM_SHUFFLE(3, 2, 1, 1));
175           y += 4;
176         }
177         if (n & 2) {
178           *((uint16_t*)y) = (uint16_t)_mm_extract_epi16(vy, 0);
179           vy = _mm_srli_epi32(vy, 16);
180           y += 2;
181         }
182         if (n & 1) {
183           *((uint8_t*)y) = (uint8_t)_mm_cvtsi128_si32(vy);
184         }
185       }
186     }
  else {
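    /*
     * Scalar fallback for n < 8: the same fixed-point arithmetic, one
     * element at a time, reading lane 0 of the broadcast parameter fields.
     */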
    const int32_t vzero_point_product =
        quantization_params->sse2.zero_point_product[0];
    const uint32_t va_multiplier = quantization_params->sse2.a_multiplier;
    const uint32_t vb_multiplier = quantization_params->sse2.b_multiplier;
    const int32_t vremainder_mask = quantization_params->sse2.remainder_mask[0];
    const int32_t vremainder_threshold =
        quantization_params->sse2.remainder_threshold[0];
    const uint32_t vshift = quantization_params->sse2.shift;
    const int32_t vy_zero_point =
        (int32_t)quantization_params->sse2.y_zero_point[0];
    const int32_t vy_max =
        (int32_t)(uint32_t)quantization_params->sse2.y_max[0];
    const int32_t vy_min =
        (int32_t)(uint32_t)quantization_params->sse2.y_min[0];

    while (n-- != 0) {
      const uint32_t vxa = (uint32_t)*a++;
      const uint32_t vxb = (uint32_t)*b++;

      /* Multiply by factors and accumulate products */
      int32_t vacc = vzero_point_product + (int32_t)(vxa * va_multiplier) +
          (int32_t)(vxb * vb_multiplier);

      /* Shift right and round */
      const int32_t vrem = (vacc & vremainder_mask) - (int32_t)(vacc < 0);

      vacc = asr_s32(vacc, vshift) + (int32_t)(vrem > vremainder_threshold);

      /* Clamp and add output zero point */
      int32_t vy = vacc + vy_zero_point;
      vy = vy >= vy_min ? vy : vy_min;
      vy = vy <= vy_max ? vy : vy_max;

      *y++ = (uint8_t)vy;
    }
  }
}