// aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-prepack.cc
#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/operator.h>
#include <qnnpack/pack.h>
#include <qnnpack_func.h>
#include <cassert>
#include <cstring>

namespace qnnpack {

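// Packs the quantized convolution kernel and bias into the blocked layout
// expected by the micro-kernel selected for this operator (depthwise,
// XZP GEMM, GEMM, or conv/deconv). The packed buffer is allocated here and
// stored in packed_weights_.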
PrePackConvWeights::PrePackConvWeights(
    const pytorch_qnnp_operator_t convolution,
    const uint8_t* kernel_zero_points,
    const uint8_t* kernel,
    const int32_t* bias) {
  enum pytorch_qnnp_ukernel_type ukernel_type = convolution->ukernel_type;
  const uint32_t kernel_width = convolution->kernel_width;
  const uint32_t kernel_height = convolution->kernel_height;
  // deconvolution leaves this 0 for now, remove when deconvolution supports 3d
  const uint32_t kernel_depth =
      convolution->kernel_depth ? convolution->kernel_depth : 1;
  const uint32_t groups = convolution->groups;

  if (convolution->transpose &&
      ukernel_type != pytorch_qnnp_ukernel_type_conv) {
    pytorch_qnnp_log_error("Wrong micro-kernel for deconvolution");
    assert(false && "QNNPACK Runtime Error.");
  }

  const size_t kernel_size = kernel_height * kernel_width * kernel_depth;
  switch (ukernel_type) {
    case pytorch_qnnp_ukernel_type_dwconv: {
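      // Depthwise packing: channels (= groups) are processed in blocks of cr.
      // c_stride rounds groups up to the next multiple of cr, assuming cr is a
      // power of two, e.g. groups = 10, cr = 8 -> c_stride = 16. Each packed
      // block appears to hold the int32 biases for its channels followed by
      // the kernel_size uint8 taps per channel, which matches the
      // (kernel_size + sizeof(int32_t)) * c_stride allocation below.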
      const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
      const uint32_t c_stride = (groups + (cr - 1)) & -cr;
      const size_t packed_weights_size =
          (sizeof(uint8_t) * kernel_size + sizeof(int32_t)) * c_stride;
      packed_weights_ = malloc(packed_weights_size);
      if (packed_weights_ == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_weights_size);
        assert(false && "QNNPACK Runtime Error.");
      }

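      // kernel_size selects the depthwise variant: 9 is a 3x3 2D kernel packed
      // in one pass, 25 is a 5x5 2D kernel packed in three column bands
      // (columns 0-1, 2-3, 4) for the multi-pass micro-kernel, and 27 is a
      // 3x3x3 3D kernel packed one width slice at a time. The byte offsets of
      // the later passes skip the taps and biases written by earlier passes.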
      switch (kernel_size) {
        case 9:
          pytorch_pack_q8dw_wrq(
              kernel_height,
              kernel_width,
              groups,
              cr,
              kernel,
              bias,
              packed_weights_);
          break;
        case 25:
          /* change this later */
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              0,
              2,
              kernel,
              bias,
              packed_weights_,
              true);
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              2,
              4,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (10 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride,
              false);
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              4,
              5,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (20 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride,
              false);
          break;
        case 27:
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              0,
              1,
              kernel,
              bias,
              packed_weights_,
              true);
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              1,
              2,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (kernel_depth * kernel_height +
                   sizeof(int32_t) / sizeof(uint8_t)) *
                      c_stride,
              false);
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              2,
              3,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (2 * kernel_depth * kernel_height +
                   sizeof(int32_t) / sizeof(uint8_t)) *
                      c_stride,
              false);
          break;
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_xzp_gemm: {
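      // The XZP path packs per-group GEMM weights in a swizzled order for the
      // q8conv_xzp micro-kernels. Output and input channels are rounded up to
      // multiples of nr and kr (n_stride and k_stride below), and each group's
      // block holds n_stride int32 biases plus the swizzled kernel bytes.
      // Note that sr is read from the kc field of the q8conv_xzp params.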
      const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr;
      const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc;
      const uint32_t n_stride =
          (convolution->group_output_channels + (nr - 1)) & -nr;
      const uint32_t k_stride =
          (convolution->group_input_channels + (kr - 1)) & -kr;

      const size_t packed_group_weights_size =
          (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) *
          n_stride;
      packed_weights_ = malloc(packed_group_weights_size * groups);
      if (packed_weights_ == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_group_weights_size * groups);
        assert(false && "QNNPACK Runtime Error.");
      }
      /* The XZP ukernel needs the padding to be 0 */
      memset(packed_weights_, 0, packed_group_weights_size * groups);

      for (uint32_t group = 0; group < groups; group++) {
        pytorch_pack_swizzle_q8gemm_brq(
            convolution->group_output_channels,
            convolution->group_input_channels,
            nr,
            kr,
            sr,
            kernel +
                group * convolution->group_output_channels *
                    convolution->group_input_channels,
            bias + group * convolution->group_output_channels,
            (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_gemm:
    case pytorch_qnnp_ukernel_type_conv: {
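      // GEMM (typically the 1x1 case) and conv/deconv share the same packed
      // layout: per group, n_stride output channels rounded up to a multiple
      // of nr, each carrying an int32 bias plus kernel_size * k_stride uint8
      // taps, with input channels rounded up to a multiple of kr.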
      const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
      const uint32_t n_stride =
          (convolution->group_output_channels + (nr - 1)) & -nr;
      const uint32_t k_stride =
          (convolution->group_input_channels + (kr - 1)) & -kr;

      const size_t packed_group_weights_size =
          (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) *
          n_stride;
      packed_weights_ = malloc(packed_group_weights_size * groups);
      if (packed_weights_ == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_group_weights_size * groups);
        assert(false && "QNNPACK Runtime Error.");
      }
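      // Padding lanes are filled with the first kernel zero point, presumably
      // so that padded taps contribute zero after zero-point subtraction in
      // the micro-kernel instead of adding garbage to the accumulators.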
      // We likely won't need this once the packing functions are appropriately
      // modified. Remove it then.
      memset(
          packed_weights_,
          kernel_zero_points[0],
          packed_group_weights_size * groups);

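      // For the GEMM variant every group packs a plain
      // group_output_channels x group_input_channels matrix. For the conv
      // variant the kernel_size spatial taps are packed as well, and the
      // transpose flag selects the deconvolution packing routine; all three
      // take per-output-channel kernel zero points.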
      switch (ukernel_type) {
        case pytorch_qnnp_ukernel_type_gemm:
          for (uint32_t group = 0; group < groups; group++) {
            pytorch_pack_q8gemm_wrq(
                convolution->group_output_channels,
                convolution->group_input_channels,
                nr,
                nr,
                kr,
                kernel +
                    group * convolution->group_output_channels *
                        convolution->group_input_channels,
                bias + group * convolution->group_output_channels,
                kernel_zero_points + group * convolution->group_output_channels,
                (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
          }
          break;
        case pytorch_qnnp_ukernel_type_conv:  // Transpose only takes this path
          for (uint32_t group = 0; group < groups; group++) {
            const uint8_t* const kernel_p = kernel +
                group * convolution->group_output_channels * kernel_size *
                    convolution->group_input_channels;
            const int32_t* const bias_p =
                bias + group * convolution->group_output_channels;
            if (convolution->transpose) { // Only run-time packing is used here
              pytorch_pack_q8deconv_wrq(
                  convolution->group_output_channels,
                  kernel_size,
                  convolution->group_input_channels,
                  nr,
                  kr,
                  kernel_p,
                  bias_p,
                  kernel_zero_points +
                      group * convolution->group_output_channels,
                  (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
            } else {
              pytorch_pack_q8conv_wrq(
                  convolution->group_output_channels,
                  kernel_size,
                  convolution->group_input_channels,
                  nr,
                  kr,
                  kernel_p,
                  bias_p,
                  kernel_zero_points +
                      group * convolution->group_output_channels,
                  (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
            }
          }
          break;
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }
      break;
    }
    default:
      PYTORCH_QNNP_UNREACHABLE;
  }
}
} // namespace qnnpack