#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/operator.h>
#include <qnnpack/pack.h>
#include <qnnpack_func.h>

#include <cassert>
#include <cstring>
7
8 namespace qnnpack {
9
// Re-packs quantized convolution weights and biases into the blocked memory
// layout expected by the selected QNNPACK micro-kernel.  Doing this once at
// operator-creation time keeps the per-inference hot path free of the
// reordering cost.
//
//   convolution        - operator descriptor: kernel geometry, group counts,
//                        micro-kernel type and the transpose (deconv) flag.
//   kernel_zero_points - per-output-channel kernel zero points; element 0 is
//                        also used as the padding fill byte on the gemm/conv
//                        path.
//   kernel             - quantized (uint8) kernel elements.
//   bias               - int32 bias values, interleaved into the packed blob.
//
// On allocation failure the code logs and asserts; note that assert() is
// compiled out under NDEBUG, so release builds would proceed with
// packed_weights_ == nullptr.
PrePackConvWeights::PrePackConvWeights(
    const pytorch_qnnp_operator_t convolution,
    const uint8_t* kernel_zero_points,
    const uint8_t* kernel,
    const int32_t* bias) {
  enum pytorch_qnnp_ukernel_type ukernel_type = convolution->ukernel_type;
  const uint32_t kernel_width = convolution->kernel_width;
  const uint32_t kernel_height = convolution->kernel_height;
  // deconvolution leaves this 0 for now, remove when deconvolution supports 3d
  const uint32_t kernel_depth =
      convolution->kernel_depth ? convolution->kernel_depth : 1;
  const uint32_t groups = convolution->groups;

  // Deconvolution is only implemented through the generic q8conv micro-kernel.
  if (convolution->transpose &&
      ukernel_type != pytorch_qnnp_ukernel_type_conv) {
    pytorch_qnnp_log_error("Wrong micro-kernel for deconvolution");
    assert(false && "QNNPACK Runtime Error.");
  }

  const size_t kernel_size = kernel_height * kernel_width * kernel_depth;
  switch (ukernel_type) {
    case pytorch_qnnp_ukernel_type_dwconv: {
      // Depthwise path: one kernel per channel (group).  Round the channel
      // count up to a multiple of cr, the micro-kernel's channel tile
      // (the `& -cr` trick assumes cr is a power of two).
      const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
      const uint32_t c_stride = (groups + (cr - 1)) & -cr;
      // Per packed channel: kernel_size uint8 weights plus one int32 bias.
      const size_t packed_weights_size =
          (sizeof(uint8_t) * kernel_size + sizeof(int32_t)) * c_stride;
      packed_weights_ = malloc(packed_weights_size);
      if (packed_weights_ == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_weights_size);
        assert(false && "QNNPACK Runtime Error.");
      }

      // Only 3x3 (9), 5x5 (25) and 3x3x3 (27) depthwise kernels are packed.
      switch (kernel_size) {
        case 9:
          // 3x3: a single direct pack of all nine taps plus the bias.
          pytorch_pack_q8dw_wrq(
              kernel_height,
              kernel_width,
              groups,
              cr,
              kernel,
              bias,
              packed_weights_);
          break;
        case 25:
          /* change this later */
          // 5x5: packed as three slices over tap indices [0,2), [2,4) and
          // [4,5) (the 7th/8th arguments of each call); the bias words are
          // written only together with the first slice (last argument true).
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              0,
              2,
              kernel,
              bias,
              packed_weights_,
              true);
          // NOTE(review): the byte offsets below assume each two-tap slice
          // occupies 10 (= 2 * 5) weight bytes per channel, and the extra
          // sizeof(int32_t) skips the bias written by the first slice —
          // confirm against the pack_q8dw layout in qnnpack/pack.h.
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              2,
              4,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (10 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride,
              false);
          pytorch_pack_q8dw_2d_w_dilation(
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_height,
              4,
              5,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (20 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride,
              false);
          break;
        case 27:
          // 3x3x3: same slicing scheme in 3D, one tap index per slice
          // ([0,1), [1,2), [2,3)); each slice contributes
          // kernel_depth * kernel_height weight bytes per channel, hence the
          // offsets below.  Bias is written with the first slice only.
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              0,
              1,
              kernel,
              bias,
              packed_weights_,
              true);
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              1,
              2,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (kernel_depth * kernel_height +
                   sizeof(int32_t) / sizeof(uint8_t)) *
                      c_stride,
              false);
          pytorch_pack_q8dw_3d_w_dilation(
              kernel_depth,
              kernel_height,
              kernel_width,
              groups,
              cr,
              0,
              kernel_depth,
              0,
              kernel_height,
              2,
              3,
              kernel,
              bias,
              (char*)packed_weights_ +
                  (2 * kernel_depth * kernel_height +
                   sizeof(int32_t) / sizeof(uint8_t)) *
                      c_stride,
              false);
          break;
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_xzp_gemm: {
      const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr;
      // NOTE(review): sr is read from the `kc` field of the params struct;
      // this matches the code as written, but verify it is not a typo for a
      // dedicated `.sr` field.
      const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc;
      // Round output/input channel counts up to the micro-kernel tiles
      // (nr and kr assumed powers of two).
      const uint32_t n_stride =
          (convolution->group_output_channels + (nr - 1)) & -nr;
      const uint32_t k_stride =
          (convolution->group_input_channels + (kr - 1)) & -kr;

      // Per output channel: kernel_size * k_stride weight bytes plus one
      // int32 bias; one such panel per group.
      const size_t packed_group_weights_size =
          (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) *
          n_stride;
      packed_weights_ = malloc(packed_group_weights_size * groups);
      if (packed_weights_ == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_group_weights_size * groups);
        assert(false && "QNNPACK Runtime Error.");
      }
      /* The XZP ukernel needs the padding to be 0 */
      memset(packed_weights_, 0, packed_group_weights_size * groups);

      // Each group gets its own contiguous packed panel.
      for (uint32_t group = 0; group < groups; group++) {
        pytorch_pack_swizzle_q8gemm_brq(
            convolution->group_output_channels,
            convolution->group_input_channels,
            nr,
            kr,
            sr,
            kernel +
                group * convolution->group_output_channels *
                    convolution->group_input_channels,
            bias + group * convolution->group_output_channels,
            (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
      }
      break;
    }
    case pytorch_qnnp_ukernel_type_gemm:
    case pytorch_qnnp_ukernel_type_conv: {
      const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
      const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
      const uint32_t n_stride =
          (convolution->group_output_channels + (nr - 1)) & -nr;
      const uint32_t k_stride =
          (convolution->group_input_channels + (kr - 1)) & -kr;

      const size_t packed_group_weights_size =
          (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) *
          n_stride;
      packed_weights_ = malloc(packed_group_weights_size * groups);
      if (packed_weights_ == nullptr) {
        pytorch_qnnp_log_error(
            "failed to allocate %zu bytes for packed weights",
            packed_group_weights_size * groups);
        assert(false && "QNNPACK Runtime Error.");
      }
      // We likely won't need this once packing functions are appropriately
      // modified. Remove it then.
      // Fill padding with the kernel zero point — presumably so padded
      // entries act as zero weights after zero-point subtraction; confirm
      // against the q8conv micro-kernel.  NOTE(review): only element 0 of
      // kernel_zero_points is used, even with per-channel zero points.
      memset(
          packed_weights_,
          kernel_zero_points[0],
          packed_group_weights_size * groups);

      switch (ukernel_type) {
        case pytorch_qnnp_ukernel_type_gemm:
          // NOTE(review): the per-group kernel offset omits kernel_size,
          // presumably because the gemm ukernel is selected only for 1x1
          // kernels — confirm.  nr is also passed twice (likely nr and np
          // tiles); verify against the pack_q8gemm signature.
          for (uint32_t group = 0; group < groups; group++) {
            pytorch_pack_q8gemm_wrq(
                convolution->group_output_channels,
                convolution->group_input_channels,
                nr,
                nr,
                kr,
                kernel +
                    group * convolution->group_output_channels *
                        convolution->group_input_channels,
                bias + group * convolution->group_output_channels,
                kernel_zero_points + group * convolution->group_output_channels,
                (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
          }
          break;
        case pytorch_qnnp_ukernel_type_conv: // The transpose can only be here
          for (uint32_t group = 0; group < groups; group++) {
            // Per-group slices of the kernel and bias arrays; the conv path
            // includes kernel_size in the kernel stride, unlike gemm above.
            const uint8_t* const kernel_p = kernel +
                group * convolution->group_output_channels * kernel_size *
                    convolution->group_input_channels;
            const int32_t* const bias_p =
                bias + group * convolution->group_output_channels;
            if (convolution
                    ->transpose) { // Note that only runtime packing is here
              pytorch_pack_q8deconv_wrq(
                  convolution->group_output_channels,
                  kernel_size,
                  convolution->group_input_channels,
                  nr,
                  kr,
                  kernel_p,
                  bias_p,
                  kernel_zero_points +
                      group * convolution->group_output_channels,
                  (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
            } else {
              pytorch_pack_q8conv_wrq(
                  convolution->group_output_channels,
                  kernel_size,
                  convolution->group_input_channels,
                  nr,
                  kr,
                  kernel_p,
                  bias_p,
                  kernel_zero_points +
                      group * convolution->group_output_channels,
                  (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size));
            }
          }
          break;
        default:
          PYTORCH_QNNP_UNREACHABLE;
      }
      break;
    }
    default:
      PYTORCH_QNNP_UNREACHABLE;
  }
} // PrePackConvWeights::PrePackConvWeights
284 } // namespace qnnpack
285