/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#ifdef _MSC_VER
#include <windows.h>
#else
#include <pthread.h>
#endif

#include <cpuinfo.h>
#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/params.h>
#include <qnnpack/q8avgpool.h>
#include <qnnpack/q8conv.h>
#include <qnnpack/q8dwconv.h>
#include <qnnpack/q8gavgpool.h>
#include <qnnpack/q8gemm.h>
#include <qnnpack/q8gemm_sparse.h>
#include <qnnpack/q8vadd.h>
#include <qnnpack/u8clamp.h>
#include <qnnpack/u8lut32norm.h>
#include <qnnpack/u8maxpool.h>
#include <qnnpack/u8rmax.h>
#include <qnnpack/x8lut.h>
#include <qnnpack/x8zip.h>

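/* Once-only guard for init(): an INIT_ONCE object plus a callback on Windows
 * (MSVC), pthread_once() everywhere else. */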
#ifdef _MSC_VER
static INIT_ONCE init_guard;
BOOL CALLBACK pytorch_qnnp_init_win(PINIT_ONCE InitOnce, PVOID Parameter, PVOID* lpContext);
#else
static pthread_once_t init_guard = PTHREAD_ONCE_INIT;
#endif

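/* Global table of selected micro-kernels and their tiling parameters.
 * Populated exactly once by init(); .initialized remains false if the
 * required ISA extensions are missing. */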
struct pytorch_qnnp_parameters pytorch_qnnp_params = {.initialized = false};

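/*
 * One-time initializer: after verifying that the required ISA extension is
 * present (NEON on 32-bit ARM, SSE2 on x86), selects the quantized
 * micro-kernel implementations and their tile sizes (mr/nr/kr, ...) for the
 * architecture this library was compiled for.
 */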
static void init(void) {
#if CPUINFO_ARCH_ARM
  if (!cpuinfo_has_arm_neon()) {
    pytorch_qnnp_log_error(
        "QNNPACK initialization failed: NEON is not supported");
    return;
  }
  pytorch_qnnp_params.q8conv = (struct pytorch_q8conv_parameters){
      .gemm = pytorch_q8gemm_ukernel_4x8__aarch32_neon,
      .conv = pytorch_q8conv_ukernel_4x8__aarch32_neon,
      .gemm_dq = pytorch_q8gemm_dq_ukernel_4x8__aarch32_neon,
      .mr = 4,
      .nr = 8,
      .kr = 1,
  };
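  /* 1x4 (row x column) block-sparse GEMM parameters: each nonzero block of
   * the sparse weight matrix spans 1 row and 4 columns (see
   * row_block_size/col_block_size below). */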
  pytorch_qnnp_params.q8gemm_sparse_c1x4 = (struct pytorch_q8gemm_sparse_parameters){
      .gemm_dq = NULL,
      .packedA_w32_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA_w32__aarch32_neon,
      .packedA_w16_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA_w16__aarch32_neon,
      .packedA_w8_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA_w8__aarch32_neon,
      .packA = pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon,
      .mr = 4,
      .nr = 8,
      .kr = 4,
      .log2_mr = 2,
      .log2_row_block_size = 0,
      .row_block_size = 1,
      .col_block_size = 4,
  };
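  /* 8x1 block-sparse variant: 8-row by 1-column nonzero blocks. */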
  pytorch_qnnp_params.q8gemm_sparse_c8x1 = (struct pytorch_q8gemm_sparse_parameters){
      .gemm_dq = NULL,
      .packedA_w32_gemm_dq = pytorch_q8gemm_dq_sparse_8x1_ukernel_4x8_packedA_w32__aarch32_neon,
      .packedA_w16_gemm_dq = pytorch_q8gemm_dq_sparse_8x1_ukernel_4x8_packedA_w16__aarch32_neon,
      .packedA_w8_gemm_dq = pytorch_q8gemm_dq_sparse_8x1_ukernel_4x8_packedA_w8__aarch32_neon,
      .packA = pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon,
      .mr = 4,
      .nr = 8,
      .kr = 4, // kr is really 1, but we set it to 4 because we reuse the 4x4 prepacking kernel
      .log2_mr = 2,
      .log2_row_block_size = 3,
      .row_block_size = 8,
      .col_block_size = 1,
  };
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
  pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){
      .gemm = pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon,
      .mr = 4,
      .nr = 8,
      .kr = 2,
      .kc = 8,
      .kthreshold = SIZE_MAX,
  };
  /* Set up the XZP threshold based on per-microarchitecture measurements. */
  switch (cpuinfo_get_core(0)->uarch) {
    case cpuinfo_uarch_cortex_a72:
      pytorch_qnnp_params.q8conv_xzp.kthreshold = 64;
      break;
    case cpuinfo_uarch_cortex_a73:
      pytorch_qnnp_params.q8conv_xzp.kthreshold = 256;
      break;
    case cpuinfo_uarch_cortex_a75:
      pytorch_qnnp_params.q8conv_xzp.kthreshold = 32;
      break;
    case cpuinfo_uarch_cortex_a76:
      pytorch_qnnp_params.q8conv_xzp.kthreshold = 16;
      break;
    default:
      break;
  }
#else
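  /* The XZP path is not used with runtime quantization; kthreshold stays at
   * SIZE_MAX so it is never selected. */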
  pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){
      .kthreshold = SIZE_MAX,
  };
#endif
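  /* Depthwise convolution micro-kernels: q8dw9 handles 3x3 (9-tap) 2D
   * windows, q8dw25 handles 5x5, and q8dw27 handles 3x3x3 3D windows. */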
  pytorch_qnnp_params.q8dw9 = (struct pytorch_q8dwconv2d_up_parameters){
      .updw = pytorch_q8dwconv_ukernel_up8x9__aarch32_neon,
      .updw_per_channel =
          pytorch_q8dwconv_ukernel_up8x9_per_channel__aarch32_neon,
      .cr = 8,
  };
  pytorch_qnnp_params.q8dw25 = (struct pytorch_q8dwconv2d_mp_parameters){
      .mpdw = pytorch_q8dwconv_ukernel_mp8x25__neon,
      .mpdw_per_channel = pytorch_q8dwconv_ukernel_mp8x25_per_channel__neon,
      .cr = 8,
  };
  pytorch_qnnp_params.q8dw27 = (struct pytorch_q8dwconv3d_mp_parameters){
      .mpdw = pytorch_q8dwconv_ukernel_mp8x27__neon,
      .cr = 8,
  };
  pytorch_qnnp_params.q8sum_rows = (struct pytorch_q8sum_rows_parameters){
      .sum_rows = pytorch_q8sumrows_ukernel_4x__neon,
      .m = 4,
  };
  pytorch_qnnp_params.q8vadd = pytorch_q8vadd_ukernel__neon;
  pytorch_qnnp_params.q8gavgpool = (struct pytorch_q8gavgpool_parameters){
      .ltnr = pytorch_q8gavgpool_ukernel_up8xm__neon,
      .genr_lemr = pytorch_q8gavgpool_ukernel_up8x7__neon,
      .genr_gtmr = pytorch_q8gavgpool_ukernel_mp8x7p7q__neon,
      .mr = 7,
      .nr = 8,
  };
  pytorch_qnnp_params.q8avgpool = (struct pytorch_q8avgpool_parameters){
      .ltkr = pytorch_q8avgpool_ukernel_up8xm__neon,
      .gekr_lemr = pytorch_q8avgpool_ukernel_up8x9__neon,
      .gekr_gtmr = pytorch_q8avgpool_ukernel_mp8x9p8q__neon,
      .mr = 9,
      .qr = 8,
      .kr = 8,
  };
  pytorch_qnnp_params.u8maxpool = (struct pytorch_u8maxpool_parameters){
      .ltkr = pytorch_u8maxpool_ukernel_sub16__neon,
      .gekr = pytorch_u8maxpool_ukernel_16x9p8q__neon,
      .mr = 9,
      .qr = 8,
      .kr = 16,
  };
  pytorch_qnnp_params.x8zip = (struct pytorch_x8zip_parameters){
      .x2 = pytorch_qnnp_x8zip_x2__neon,
      .x3 = pytorch_qnnp_x8zip_x3__neon,
      .x4 = pytorch_qnnp_x8zip_x4__neon,
      .xm = pytorch_qnnp_x8zip_xm__neon,
  };
  pytorch_qnnp_params.u8clamp = pytorch_u8clamp_ukernel__neon;
  pytorch_qnnp_params.u8rmax = pytorch_u8rmax_ukernel__neon;
  pytorch_qnnp_params.u8lut32norm = pytorch_u8lut32norm_ukernel__scalar;
  pytorch_qnnp_params.x8lut = pytorch_x8lut_ukernel__scalar;
#elif CPUINFO_ARCH_ARM64
  pytorch_qnnp_params.q8gemm_sparse_c1x4 = (struct pytorch_q8gemm_sparse_parameters){
      .gemm_dq = NULL,
      .packedA_w32_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w32__aarch64_neon,
      .packedA_w16_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w16__aarch64_neon,
      .packedA_w8_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w8__aarch64_neon,
      .packA = pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch64_neon,
      .mr = 8,
      .nr = 8,
      .kr = 4,
      .log2_mr = 3,
      .log2_row_block_size = 0,
      .row_block_size = 1,
      .col_block_size = 4,
  };
  pytorch_qnnp_params.q8gemm_sparse_c8x1 = (struct pytorch_q8gemm_sparse_parameters){
      .gemm_dq = NULL,
      .packedA_w32_gemm_dq = pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w32__aarch64_neon,
      .packedA_w16_gemm_dq = pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w16__aarch64_neon,
      .packedA_w8_gemm_dq = pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w8__aarch64_neon,
      .packA = pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch64_neon,
      .mr = 8,
      .nr = 8,
      .kr = 4, // kr is really 1, but we set it to 4 because we reuse the 8x4 prepacking kernel
      .log2_mr = 3,
      .log2_row_block_size = 3,
      .row_block_size = 8,
      .col_block_size = 1,
  };
  pytorch_qnnp_params.q8conv = (struct pytorch_q8conv_parameters){
      .gemm = pytorch_q8gemm_ukernel_8x8__aarch64_neon,
      .conv = pytorch_q8conv_ukernel_8x8__aarch64_neon,
      .gemm_dq = pytorch_q8gemm_dq_ukernel_8x8__aarch64_neon,
      .mr = 8,
      .nr = 8,
      .kr = 1,
  };
  pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){
      .kthreshold = SIZE_MAX,
  };
  pytorch_qnnp_params.q8dw9 = (struct pytorch_q8dwconv2d_up_parameters){
      .updw = pytorch_q8dwconv_ukernel_up8x9__neon,
      .updw_per_channel = pytorch_q8dwconv_ukernel_up8x9_per_channel__neon,
      .cr = 8,
  };
  pytorch_qnnp_params.q8dw25 = (struct pytorch_q8dwconv2d_mp_parameters){
      .mpdw = pytorch_q8dwconv_ukernel_mp8x25__neon,
      .mpdw_per_channel = pytorch_q8dwconv_ukernel_mp8x25_per_channel__neon,
      .cr = 8,
  };
  pytorch_qnnp_params.q8dw27 = (struct pytorch_q8dwconv3d_mp_parameters){
      .mpdw = pytorch_q8dwconv_ukernel_mp8x27__neon,
      .cr = 8,
  };
  pytorch_qnnp_params.q8vadd = pytorch_q8vadd_ukernel__neon;
  pytorch_qnnp_params.q8gavgpool = (struct pytorch_q8gavgpool_parameters){
      .ltnr = pytorch_q8gavgpool_ukernel_up8xm__neon,
      .genr_lemr = pytorch_q8gavgpool_ukernel_up8x7__neon,
      .genr_gtmr = pytorch_q8gavgpool_ukernel_mp8x7p7q__neon,
      .mr = 7,
      .nr = 8,
  };
  pytorch_qnnp_params.q8avgpool = (struct pytorch_q8avgpool_parameters){
      .ltkr = pytorch_q8avgpool_ukernel_up8xm__neon,
      .gekr_lemr = pytorch_q8avgpool_ukernel_up8x9__neon,
      .gekr_gtmr = pytorch_q8avgpool_ukernel_mp8x9p8q__neon,
      .mr = 9,
      .qr = 8,
      .kr = 8,
  };
  pytorch_qnnp_params.u8maxpool = (struct pytorch_u8maxpool_parameters){
      .ltkr = pytorch_u8maxpool_ukernel_sub16__neon,
      .gekr = pytorch_u8maxpool_ukernel_16x9p8q__neon,
      .mr = 9,
      .qr = 8,
      .kr = 16,
  };
  pytorch_qnnp_params.x8zip = (struct pytorch_x8zip_parameters){
      .x2 = pytorch_qnnp_x8zip_x2__neon,
      .x3 = pytorch_qnnp_x8zip_x3__neon,
      .x4 = pytorch_qnnp_x8zip_x4__neon,
      .xm = pytorch_qnnp_x8zip_xm__neon,
  };
  pytorch_qnnp_params.u8clamp = pytorch_u8clamp_ukernel__neon;
  pytorch_qnnp_params.u8rmax = pytorch_u8rmax_ukernel__neon;
  pytorch_qnnp_params.u8lut32norm = pytorch_u8lut32norm_ukernel__scalar;
  pytorch_qnnp_params.x8lut = pytorch_x8lut_ukernel__scalar;
#elif CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  if (!cpuinfo_has_x86_sse2()) {
    pytorch_qnnp_log_error(
        "QNNPACK initialization failed: SSE2 is not supported");
    return;
  }
  pytorch_qnnp_params.q8conv = (struct pytorch_q8conv_parameters){
      .gemm = pytorch_q8gemm_ukernel_4x4c2__sse2,
      .conv = pytorch_q8conv_ukernel_4x4c2__sse2,
      .gemm_dq = pytorch_q8gemm_dq_ukernel_4x4c2__sse2,
      .mr = 4,
      .nr = 4,
      .kr = 2,
  };
  pytorch_qnnp_params.q8gemm_sparse_c1x4 = (struct pytorch_q8gemm_sparse_parameters){
      .gemm_dq = NULL,
      .packedA_w32_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA_w32__sse2,
      .packedA_w16_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA_w16__sse2,
      .packedA_w8_gemm_dq = pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA_w8__sse2,
      .packA = pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2,
      .mr = 8,
      .nr = 4,
      .kr = 4,
      .log2_mr = 3,
      .log2_row_block_size = 0,
      .row_block_size = 1,
      .col_block_size = 4,
  };
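  /* No 8x1 block-sparse micro-kernels are available for x86/SSE2; the kernel
   * pointers are left NULL, so callers must check before taking this path. */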
  pytorch_qnnp_params.q8gemm_sparse_c8x1 = (struct pytorch_q8gemm_sparse_parameters){
      .gemm_dq = NULL,
      .packedA_w32_gemm_dq = NULL,
      .packedA_w16_gemm_dq = NULL,
      .packedA_w8_gemm_dq = NULL,
      .packA = NULL,
      .mr = 4,
      .nr = 8,
      .kr = 1,
      .log2_mr = 2,
      .log2_row_block_size = 3,
      .row_block_size = 8,
      .col_block_size = 1,
  };
  pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){
      .kthreshold = SIZE_MAX,
  };
  pytorch_qnnp_params.q8dw9 = (struct pytorch_q8dwconv2d_up_parameters){
      .updw = pytorch_q8dwconv_ukernel_up8x9__sse2,
      .updw_per_channel = pytorch_q8dwconv_ukernel_up8x9_per_channel__sse2,
      .cr = 8,
  };
  pytorch_qnnp_params.q8dw25 = (struct pytorch_q8dwconv2d_mp_parameters){
      .mpdw = pytorch_q8dwconv_ukernel_mp8x25__sse2,
      .mpdw_per_channel = pytorch_q8dwconv_ukernel_mp8x25_per_channel__sse2,
      .cr = 8,
  };
  pytorch_qnnp_params.q8dw27 = (struct pytorch_q8dwconv3d_mp_parameters){
      .mpdw = pytorch_q8dwconv_ukernel_mp8x27__sse2,
      .cr = 8,
  };
  pytorch_qnnp_params.q8vadd = pytorch_q8vadd_ukernel__sse2;
  pytorch_qnnp_params.q8gavgpool = (struct pytorch_q8gavgpool_parameters){
      .ltnr = pytorch_q8gavgpool_ukernel_up8xm__sse2,
      .genr_lemr = pytorch_q8gavgpool_ukernel_up8x7__sse2,
      .genr_gtmr = pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2,
      .mr = 7,
      .nr = 8,
  };
  pytorch_qnnp_params.q8avgpool = (struct pytorch_q8avgpool_parameters){
      .ltkr = pytorch_q8avgpool_ukernel_up8xm__sse2,
      .gekr_lemr = pytorch_q8avgpool_ukernel_up8x9__sse2,
      .gekr_gtmr = pytorch_q8avgpool_ukernel_mp8x9p8q__sse2,
      .mr = 9,
      .qr = 8,
      .kr = 8,
  };
  pytorch_qnnp_params.u8maxpool = (struct pytorch_u8maxpool_parameters){
      .ltkr = pytorch_u8maxpool_ukernel_sub16__sse2,
      .gekr = pytorch_u8maxpool_ukernel_16x9p8q__sse2,
      .mr = 9,
      .qr = 8,
      .kr = 16,
  };
  pytorch_qnnp_params.x8zip = (struct pytorch_x8zip_parameters){
      .x2 = pytorch_qnnp_x8zip_x2__sse2,
      .x3 = pytorch_qnnp_x8zip_x3__sse2,
      .x4 = pytorch_qnnp_x8zip_x4__sse2,
      .xm = pytorch_qnnp_x8zip_xm__sse2,
  };
  pytorch_qnnp_params.u8clamp = pytorch_u8clamp_ukernel__sse2;
  pytorch_qnnp_params.u8rmax = pytorch_u8rmax_ukernel__sse2;
  pytorch_qnnp_params.u8lut32norm = pytorch_u8lut32norm_ukernel__scalar;
  pytorch_qnnp_params.x8lut = pytorch_x8lut_ukernel__scalar;
#else
#error "Unsupported architecture"
#endif
  pytorch_qnnp_params.initialized = true;
}

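/*
 * Public entry point. Thread-safe: the underlying init() runs at most once
 * per process via InitOnceExecuteOnce()/pthread_once(). Returns
 * pytorch_qnnp_status_out_of_memory if cpuinfo cannot be initialized, and
 * pytorch_qnnp_status_unsupported_hardware if init() bailed out (e.g. no
 * NEON/SSE2). A minimal caller sketch (illustrative only):
 *
 *   if (pytorch_qnnp_initialize() != pytorch_qnnp_status_success) {
 *     // handle unsupported hardware / out-of-memory
 *   }
 *   // ... use QNNPACK operators ...
 *   pytorch_qnnp_deinitialize();
 */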
enum pytorch_qnnp_status pytorch_qnnp_initialize(void) {
  if (!cpuinfo_initialize()) {
    return pytorch_qnnp_status_out_of_memory;
  }
#ifdef _MSC_VER
  InitOnceExecuteOnce(&init_guard, pytorch_qnnp_init_win, NULL, NULL);
#else
  pthread_once(&init_guard, &init);
#endif
  if (pytorch_qnnp_params.initialized) {
    return pytorch_qnnp_status_success;
  } else {
    return pytorch_qnnp_status_unsupported_hardware;
  }
}

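/* Releases resources held by cpuinfo; always reports success. */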
enum pytorch_qnnp_status pytorch_qnnp_deinitialize(void) {
  cpuinfo_deinitialize();
  return pytorch_qnnp_status_success;
}

#ifdef _MSC_VER
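/* Windows one-time-initialization callback: adapts init() to the
 * InitOnceExecuteOnce() API. Always returns TRUE; init() failures are
 * instead detected through pytorch_qnnp_params.initialized. */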
BOOL CALLBACK pytorch_qnnp_init_win(PINIT_ONCE InitOnce, PVOID Parameter, PVOID* lpContext) {
  init();
  return TRUE;
}
#endif