1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <stdbool.h>
11*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
12*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
13*4bdc9457SAndroid Build Coastguard Worker #include <string.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <math.h>
15*4bdc9457SAndroid Build Coastguard Worker
16*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
17*4bdc9457SAndroid Build Coastguard Worker
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/indirection.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/params.h>
26*4bdc9457SAndroid Build Coastguard Worker
27*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_ENABLE_GEMM_M_SPECIALIZATION
28*4bdc9457SAndroid Build Coastguard Worker #error "XNN_ENABLE_GEMM_M_SPECIALIZATION is not defined"
29*4bdc9457SAndroid Build Coastguard Worker #endif
30*4bdc9457SAndroid Build Coastguard Worker
// Shared implementation behind the typed xnn_create_deconvolution2d_nhwc_* entry points.
//
// Validates the operator geometry, allocates the operator descriptor, packs the
// kernel and bias into the layout expected by the selected micro-kernel type, and
// allocates a zero buffer filled with `input_padding_byte`.
//
// One of two packing layouts is selected:
//  - xnn_ukernel_type_subconv2d: chosen when at least one stride is > 1, there is
//    no dilation, and each stride is <= the kernel extent in that dimension. The
//    kernel is packed as stride_height * stride_width sub-kernels (each with its own
//    bias copy), and one subconvolution_params entry is recorded per sub-kernel.
//  - xnn_ukernel_type_igemm: every other case; the kernel is packed as a grouped
//    convolution kernel via pack_conv_goki_w.
//
// Parameters:
//   output_padding_*              stored as the operator's padding_top/right/bottom/left.
//   kernel_*, stride_*, dilation_*  spatial geometry; every value must be non-zero.
//   groups, group_*_channels      channel layout; every value must be non-zero.
//   input_pixel_stride            must be >= groups * group_input_channels.
//   output_pixel_stride           must be >= groups * group_output_channels.
//   kernel, bias                  forwarded to the packing routine.
//   flags                         not read by this function.
//   log2_input_element_size       log2 of the input element size in bytes (sizes the zero buffer).
//   log2_filter_element_size      log2 of the packed filter element size in bytes.
//   bias_element_size             size in bytes of one packed bias element.
//   pack_conv_goki_w              packing routine for the igemm layout.
//   pack_deconv_goki_w            packing routine for the subconv2d layout.
//   packing_params                opaque, forwarded to the packing routine.
//   input_padding_byte            byte value the zero buffer is filled with.
//   packed_weights_padding_byte   forwarded to xnn_get_pointer_to_write_weights.
//   params, params_size           micro-kernel parameters copied into the operator.
//   gemm_parameters               supplies mr, nr, log2_kr and log2_sr.
//   gemm_ukernels                 per-row-count GEMM/IGEMM kernels; entries whose
//                                 default function is NULL are not copied.
//   operator_type                 recorded on the operator and used in log messages.
//   caches                        optional; when non-NULL its weights_cache is attached.
//   deconvolution_op_out          receives the new operator on success.
//
// Returns:
//   xnn_status_success            on success.
//   xnn_status_uninitialized      if XNNPACK has not been initialized.
//   xnn_status_invalid_parameter  if a size, stride or count check fails.
//   xnn_status_out_of_memory      if any allocation fails.
// On failure the partially constructed operator is released with xnn_delete_operator.
static enum xnn_status create_deconvolution2d_nhwc(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    const void* kernel,
    const void* bias,
    uint32_t flags,
    uint32_t log2_input_element_size,
    uint32_t log2_filter_element_size,
    uint32_t bias_element_size,
    xnn_pack_conv_goki_w_function pack_conv_goki_w,
    xnn_pack_deconv_goki_w_function pack_deconv_goki_w,
    const void* packing_params,
    int input_padding_byte,
    int packed_weights_padding_byte,
    const void* params,
    size_t params_size,
    const struct gemm_parameters* gemm_parameters,
    const struct gemm_fused_ukernels* gemm_ukernels,
    enum xnn_operator_type operator_type,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  xnn_operator_t deconvolution_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (kernel_width == 0 || kernel_height == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), kernel_width, kernel_height);
    goto error;
  }

  if (stride_width == 0 || stride_height == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " stride: stride dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), stride_width, stride_height);
    goto error;
  }

  if (dilation_width == 0 || dilation_height == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), dilation_width, dilation_height);
    goto error;
  }

  if (groups == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 " groups: number of groups must be non-zero",
      xnn_operator_type_to_string(operator_type), groups);
    goto error;
  }

  if (group_input_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu input channels per group: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_input_channels);
    goto error;
  }

  if (group_output_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu output channels per group: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_output_channels);
    goto error;
  }

  const size_t input_channels = groups * group_input_channels;
  if (input_pixel_stride < input_channels) {
    // This check is about the *input* stride, so report the input channel count.
    xnn_log_error(
      "failed to create %s operator with input pixel stride of %zu: "
      "stride must be at least as large as the number of input channels (%" PRIu32 "x%zu)",
      xnn_operator_type_to_string(operator_type),
      input_pixel_stride, groups, group_input_channels);
    goto error;
  }

  const size_t output_channels = groups * group_output_channels;
  if (output_pixel_stride < output_channels) {
    xnn_log_error(
      "failed to create %s operator with output pixel stride of %zu: "
      "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
      xnn_operator_type_to_string(operator_type),
      output_pixel_stride, groups, group_output_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  deconvolution_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (deconvolution_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  if (caches != NULL) {
    deconvolution_op->weights_cache = caches->weights_cache;
  }

  const uint32_t mr = gemm_parameters->mr;
  const uint32_t nr = gemm_parameters->nr;
  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;

  // Keep the packing geometry in size_t: the products below can exceed 32 bits for
  // large kernels / channel counts, and narrowing to uint32_t would under-size the
  // packed weights buffer.
  const size_t n_stride = round_up(group_output_channels, nr);
  const size_t k_stride = round_up_po2(group_input_channels, kr * sr);
  const size_t kernel_size = (size_t) kernel_height * (size_t) kernel_width;

  enum xnn_ukernel_type ukernel_type = xnn_ukernel_type_igemm;
  // igemm layout: per output channel, the full kernel plus one bias element.
  size_t packed_group_weights_size =
    (((kernel_size * k_stride) << log2_filter_element_size) + bias_element_size) * n_stride;
  if (max(stride_height, stride_width) > 1 && max(dilation_height, dilation_width) == 1 &&
      stride_width <= kernel_width && stride_height <= kernel_height)
  {
    ukernel_type = xnn_ukernel_type_subconv2d;
    const size_t subkernels = (size_t) stride_height * (size_t) stride_width;
    // subconv2d layout: the kernel taps are split among the sub-kernels, but every
    // sub-kernel carries its own bias element.
    packed_group_weights_size = n_stride *
      (((kernel_size * k_stride) << log2_filter_element_size) + bias_element_size * subkernels);

    const size_t subconvolution_buffer_size = sizeof(struct subconvolution_params) * subkernels;
    deconvolution_op->subconvolution_buffer = xnn_allocate_zero_memory(subconvolution_buffer_size);
    if (deconvolution_op->subconvolution_buffer == NULL) {
      xnn_log_error(
        "failed to allocate %zu bytes for %s operator subconvolution buffer",
        subconvolution_buffer_size, xnn_operator_type_to_string(operator_type));
      goto error;
    }

    // Record, per (offset_y, offset_x) sub-kernel in row-major order, the strides the
    // micro-kernel needs. Since offset < stride <= kernel extent, the subtractions
    // below cannot underflow.
    struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
    for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
      for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
        const size_t subkernel_height = divide_round_up(kernel_height - offset_y, stride_height);
        const size_t subkernel_width = divide_round_up(kernel_width - offset_x, stride_width);
        const size_t subkernel_size = subkernel_height * subkernel_width;

        subconvolution_params->indirection_x_stride = sizeof(void*) * subkernel_size;
        subconvolution_params->w_stride = bias_element_size + ((k_stride * subkernel_size) << log2_filter_element_size);
        subconvolution_params++;
      }
    }
  }

  const size_t aligned_total_weights_size = round_up_po2(packed_group_weights_size * groups, XNN_ALLOCATION_ALIGNMENT);
  void* weights_ptr = xnn_get_pointer_to_write_weights(
    deconvolution_op, aligned_total_weights_size, packed_weights_padding_byte);
  if (weights_ptr == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator packed weights",
      aligned_total_weights_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }

  switch (ukernel_type) {
    case xnn_ukernel_type_igemm:
      pack_conv_goki_w(
        groups, group_output_channels, kernel_size, group_input_channels,
        nr, kr, sr,
        kernel, bias, weights_ptr,
        0 /* extra bytes */,
        packing_params);
      break;
    case xnn_ukernel_type_subconv2d:
      pack_deconv_goki_w(
        groups, group_output_channels, kernel_height, kernel_width, group_input_channels,
        stride_height, stride_width,
        nr, kr, sr,
        kernel, bias, weights_ptr, deconvolution_op->subconvolution_buffer,
        packing_params);
      // We assume that the first subconvolution param weights point to the start of the weights, this is used to check
      // if the weights cache has moved.
      assert(deconvolution_op->subconvolution_buffer->weights == weights_ptr);
      break;
    default:
      XNN_UNREACHABLE;
  }

  if (use_weights_cache(deconvolution_op)) {
    deconvolution_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
      deconvolution_op->weights_cache, weights_ptr, aligned_total_weights_size);
  }

  // Zero buffer: one row of k_stride input elements plus XNN_EXTRA_BYTES, filled
  // with the caller-chosen padding byte.
  const size_t zero_size = (k_stride << log2_input_element_size) + XNN_EXTRA_BYTES;
  deconvolution_op->zero_buffer = xnn_allocate_simd_memory(zero_size);
  if (deconvolution_op->zero_buffer == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator zero padding",
      zero_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }
  memset(deconvolution_op->zero_buffer, input_padding_byte, zero_size);

  deconvolution_op->padding_top = output_padding_top;
  deconvolution_op->padding_right = output_padding_right;
  deconvolution_op->padding_bottom = output_padding_bottom;
  deconvolution_op->padding_left = output_padding_left;

  deconvolution_op->kernel_height = kernel_height;
  deconvolution_op->kernel_width = kernel_width;
  deconvolution_op->stride_height = stride_height;
  deconvolution_op->stride_width = stride_width;
  deconvolution_op->dilation_height = dilation_height;
  deconvolution_op->dilation_width = dilation_width;
  deconvolution_op->groups = groups;
  deconvolution_op->group_input_channels = group_input_channels;
  deconvolution_op->group_output_channels = group_output_channels;
  deconvolution_op->input_pixel_stride = input_pixel_stride;
  deconvolution_op->output_pixel_stride = output_pixel_stride;

  memcpy(&deconvolution_op->params, params, params_size);
  deconvolution_op->type = operator_type;
  deconvolution_op->ukernel.type = ukernel_type;
  deconvolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
    .mr = mr,
    .nr = nr,
    .kr = kr,
    .sr = sr,
  };

  // Copy only the per-row-count kernels that are actually provided.
  assert(XNN_MAX_MR >= mr);
  for (size_t i = 0; i < mr; i++) {
    if (gemm_ukernels->gemm[i].function[XNN_UARCH_DEFAULT] != NULL) {
      deconvolution_op->ukernel.igemm.gemm_cases[i] = gemm_ukernels->gemm[i];
    }
    if (gemm_ukernels->igemm[i].function[XNN_UARCH_DEFAULT] != NULL) {
      deconvolution_op->ukernel.igemm.igemm_cases[i] = gemm_ukernels->igemm[i];
    }
  }

  deconvolution_op->state = xnn_run_state_invalid;

  *deconvolution_op_out = deconvolution_op;
  return xnn_status_success;

error:
  xnn_delete_operator(deconvolution_op);
  return status;
}
287*4bdc9457SAndroid Build Coastguard Worker
// Creates a deconvolution operator for signed 8-bit quantized (QS8) NHWC tensors.
//
// Validates the quantization parameters and derives the requantization scale
// input_scale * kernel_scale / output_scale. It then initializes the QS8 micro-kernel
// params and delegates to create_deconvolution2d_nhwc with 1-byte input and filter
// elements, 4-byte (int32) bias elements and the QS8 packing routines. The zero
// buffer is filled with input_zero_point.
//
// Returns:
//   xnn_status_invalid_parameter      if any scale is not positive and normal, or
//                                     output_min is not below output_max.
//   xnn_status_unsupported_parameter  if the requantization scale is >= 256.
//   otherwise                         the status returned by create_deconvolution2d_nhwc.
enum xnn_status xnn_create_deconvolution2d_nhwc_qs8(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    int8_t input_zero_point,
    float input_scale,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  // Zero-initialize so the operator never copies indeterminate bytes, even when no
  // QS8 params initializer is registered.
  union xnn_qs8_conv_minmax_params params;
  memset(&params, 0, sizeof(params));
  if XNN_LIKELY(xnn_params.qs8.gemm.init.qs8 != NULL) {
    xnn_params.qs8.gemm.init.qs8(&params,
      requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qs8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
  };
  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_conv_goki_w_function) xnn_pack_qs8_conv_goki_w,
    (xnn_pack_deconv_goki_w_function) xnn_pack_qs8_deconv_goki_w,
    &packing_params, input_zero_point /* input padding byte */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qs8.gemm, &xnn_params.qs8.gemm.minmax,
    xnn_operator_type_deconvolution_nhwc_qs8,
    caches,
    deconvolution_op_out);
}
383*4bdc9457SAndroid Build Coastguard Worker
// Creates a 2D deconvolution operator in NHWC layout for unsigned 8-bit
// quantized (QU8) input, kernel, and output.
//
// Checks performed here:
//  - input_scale, kernel_scale, and output_scale must each be positive and
//    normal (non-zero, finite, not subnormal, not NaN);
//  - output_min must be strictly below output_max;
//  - the requantization scale input_scale * kernel_scale / output_scale must
//    be below 256.0.
// All remaining arguments (geometry, strides, channel counts, kernel, bias,
// flags, caches, deconvolution_op_out) are forwarded unchanged to
// create_deconvolution2d_nhwc (defined elsewhere in this file), which
// presumably validates them and constructs the operator.
//
// Returns xnn_status_invalid_parameter for a bad scale or output range,
// xnn_status_unsupported_parameter for a requantization scale >= 256.0, and
// otherwise the status returned by create_deconvolution2d_nhwc.
enum xnn_status xnn_create_deconvolution2d_nhwc_qu8(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  // `x <= 0.0f || !isnormal(x)` rejects zero, negatives, subnormals, infinities, and NaN.
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Combined scale applied to accumulators to produce the output; values of
  // 256.0 and above are rejected as unsupported.
  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  union xnn_qu8_conv_minmax_params params;
  if XNN_LIKELY(xnn_params.qu8.gemm.init.qu8 != NULL) {
    xnn_params.qu8.gemm.init.qu8(&params,
      kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };
  // Input and packed-weight padding use the respective zero points, i.e. the
  // quantized encoding of 0.0, so padded elements behave like real zeros.
  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_conv_goki_w_function) xnn_pack_qu8_conv_goki_w,
    (xnn_pack_deconv_goki_w_function) xnn_pack_qu8_deconv_goki_w,
    &packing_params, input_zero_point /* input padding byte */, kernel_zero_point /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
    xnn_operator_type_deconvolution_nhwc_qu8,
    caches,
    deconvolution_op_out);
}
481*4bdc9457SAndroid Build Coastguard Worker
// Creates a 2D deconvolution operator in NHWC layout for half-precision (F16)
// input and output.
//
// Returns xnn_status_unsupported_hardware when XNN_INIT_FLAG_F16 is not set in
// xnn_params.init_flags. Returns xnn_status_invalid_parameter when either
// output bound is NaN, or when output_min is not strictly below output_max
// after both bounds are rounded to half precision. Otherwise returns the
// status of create_deconvolution2d_nhwc (defined elsewhere in this file), to
// which all geometry, kernel, bias, flags, caches, and deconvolution_op_out
// arguments are forwarded.
//
// With XNN_FLAG_FP32_STATIC_WEIGHTS in `flags`, the f32_to_f16 packing
// routines are used instead of the f16 ones. Presumably `kernel` and `bias`
// then point to fp32 data that is converted during packing; confirm against
// the packers.
enum xnn_status xnn_create_deconvolution2d_nhwc_f16(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    const void* kernel,
    const void* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_F16) != XNN_INIT_FLAG_F16) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16));
    return xnn_status_unsupported_hardware;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

  // Round the bounds to half precision before the range check. The check then
  // sees exactly the values handed to the f16 params initializer, and bounds
  // that are distinct in fp32 but collapse to one half value are rejected.
  const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min);
  const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max);
  output_min = fp16_ieee_to_fp32_value(output_min_as_half);
  output_max = fp16_ieee_to_fp32_value(output_max_as_half);
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const struct gemm_parameters* gemm_parameters = &xnn_params.f16.gemm;
  const struct gemm_fused_ukernels* gemm_ukernels = &gemm_parameters->minmax;
  // For the unbounded range [-inf, +inf], prefer the `linear` micro-kernel set
  // when it provides a GEMM kernel for the full MR tile.
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  if (linear_activation && gemm_parameters->linear.gemm[gemm_parameters->mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernels = &gemm_parameters->linear;
  }

  union xnn_f16_minmax_params params;
  // Guard the same init pointer that is actually invoked.
  if XNN_LIKELY(gemm_parameters->init.f16 != NULL) {
    gemm_parameters->init.f16(&params, output_min_as_half, output_max_as_half);
  }

  xnn_pack_conv_goki_w_function pack_conv_goki_w = (xnn_pack_conv_goki_w_function) xnn_pack_f16_conv_goki_w;
  xnn_pack_deconv_goki_w_function pack_deconv_goki_w = (xnn_pack_deconv_goki_w_function) xnn_pack_f16_deconv_goki_w;
  if (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) {
    pack_conv_goki_w = (xnn_pack_conv_goki_w_function) xnn_pack_f32_to_f16_conv_goki_w;
    pack_deconv_goki_w = (xnn_pack_deconv_goki_w_function) xnn_pack_f32_to_f16_deconv_goki_w;
  }

  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    1 /* log2(sizeof(input element)) = log2(sizeof(uint16_t)) */,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    pack_conv_goki_w,
    pack_deconv_goki_w,
    NULL /* packing params */, 0 /* input padding byte */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    gemm_parameters, gemm_ukernels,
    xnn_operator_type_deconvolution_nhwc_f16,
    caches,
    deconvolution_op_out);
}
576*4bdc9457SAndroid Build Coastguard Worker
// Creates a 2D deconvolution operator in NHWC layout for single-precision
// (F32) input and output.
//
// Returns xnn_status_invalid_parameter when either output bound is NaN or when
// output_min is not strictly below output_max. Otherwise returns the status of
// create_deconvolution2d_nhwc (defined elsewhere in this file), to which all
// geometry, kernel, bias, flags, caches, and deconvolution_op_out arguments
// are forwarded.
//
// Micro-kernel selection:
//  - if the default GEMM parameters' nr exceeds group_output_channels, the
//    alternative xnn_params.f32.gemm2 set is used when it has an IGEMM kernel
//    for its full MR tile;
//  - for the unbounded range [-inf, +inf], the `linear` kernel set is used when
//    it has a GEMM kernel for the full MR tile.
enum xnn_status xnn_create_deconvolution2d_nhwc_f32(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const struct gemm_parameters* gemm_parameters = &xnn_params.f32.gemm;
  if (gemm_parameters->nr > group_output_channels) {
    // Default micro-kernel is suboptimal. Try to find a better micro-kernel.
    const struct gemm_parameters* gemm2_parameters = &xnn_params.f32.gemm2;
    if (gemm2_parameters->minmax.igemm[gemm2_parameters->mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
      gemm_parameters = gemm2_parameters;
    }
  }

  const struct gemm_fused_ukernels* gemm_ukernels = &gemm_parameters->minmax;
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  if (linear_activation && gemm_parameters->linear.gemm[gemm_parameters->mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernels = &gemm_parameters->linear;
  }

  union xnn_f32_minmax_params params;
  // gemm_parameters may now point at xnn_params.f32.gemm2. Check the init
  // routine that is actually called, not the default set's, so we never call
  // through a NULL pointer or initialize params for the wrong kernels.
  if XNN_LIKELY(gemm_parameters->init.f32 != NULL) {
    gemm_parameters->init.f32(&params, output_min, output_max);
  }
  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    (xnn_pack_conv_goki_w_function) xnn_pack_f32_conv_goki_w,
    (xnn_pack_deconv_goki_w_function) xnn_pack_f32_deconv_goki_w,
    NULL /* packing params */, 0 /* input padding byte */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    gemm_parameters, gemm_ukernels,
    xnn_operator_type_deconvolution_nhwc_f32,
    caches,
    deconvolution_op_out);
}
661*4bdc9457SAndroid Build Coastguard Worker
setup_conv_path(xnn_operator_t deconvolution_op,size_t batch_size,size_t input_height,size_t input_width,const void * input,size_t output_height,size_t output_width,void * output,uint32_t log2_input_element_size,uint32_t log2_filter_element_size,uint32_t bias_element_size,uint32_t log2_output_element_size,const void * params,size_t params_size,size_t num_threads)662*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status setup_conv_path(
663*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op,
664*4bdc9457SAndroid Build Coastguard Worker size_t batch_size,
665*4bdc9457SAndroid Build Coastguard Worker size_t input_height,
666*4bdc9457SAndroid Build Coastguard Worker size_t input_width,
667*4bdc9457SAndroid Build Coastguard Worker const void* input,
668*4bdc9457SAndroid Build Coastguard Worker size_t output_height,
669*4bdc9457SAndroid Build Coastguard Worker size_t output_width,
670*4bdc9457SAndroid Build Coastguard Worker void* output,
671*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_input_element_size,
672*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_filter_element_size,
673*4bdc9457SAndroid Build Coastguard Worker uint32_t bias_element_size,
674*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_output_element_size,
675*4bdc9457SAndroid Build Coastguard Worker const void* params,
676*4bdc9457SAndroid Build Coastguard Worker size_t params_size,
677*4bdc9457SAndroid Build Coastguard Worker size_t num_threads)
678*4bdc9457SAndroid Build Coastguard Worker {
679*4bdc9457SAndroid Build Coastguard Worker assert(deconvolution_op->ukernel.type == xnn_ukernel_type_igemm);
680*4bdc9457SAndroid Build Coastguard Worker
681*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_height = deconvolution_op->kernel_height;
682*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_width = deconvolution_op->kernel_width;
683*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_size = kernel_height * kernel_width;
684*4bdc9457SAndroid Build Coastguard Worker
685*4bdc9457SAndroid Build Coastguard Worker const size_t groups = deconvolution_op->groups;
686*4bdc9457SAndroid Build Coastguard Worker const size_t output_size = output_height * output_width;
687*4bdc9457SAndroid Build Coastguard Worker size_t mr = deconvolution_op->ukernel.igemm.mr;
688*4bdc9457SAndroid Build Coastguard Worker
689*4bdc9457SAndroid Build Coastguard Worker struct xnn_hmp_igemm_ukernel igemm_ukernel = deconvolution_op->ukernel.igemm.igemm_cases[mr - 1];
690*4bdc9457SAndroid Build Coastguard Worker if (output_size == 1 && deconvolution_op->ukernel.igemm.igemm_cases[0].function[XNN_UARCH_DEFAULT] != NULL) {
691*4bdc9457SAndroid Build Coastguard Worker mr = 1;
692*4bdc9457SAndroid Build Coastguard Worker igemm_ukernel = deconvolution_op->ukernel.igemm.igemm_cases[0];
693*4bdc9457SAndroid Build Coastguard Worker }
694*4bdc9457SAndroid Build Coastguard Worker
695*4bdc9457SAndroid Build Coastguard Worker const size_t tiled_output_size = round_up(output_size, mr);
696*4bdc9457SAndroid Build Coastguard Worker const size_t indirection_buffer_size = sizeof(void*) * kernel_size * tiled_output_size;
697*4bdc9457SAndroid Build Coastguard Worker
698*4bdc9457SAndroid Build Coastguard Worker if (input_height != deconvolution_op->last_input_height ||
699*4bdc9457SAndroid Build Coastguard Worker input_width != deconvolution_op->last_input_width)
700*4bdc9457SAndroid Build Coastguard Worker {
701*4bdc9457SAndroid Build Coastguard Worker const void** indirection_buffer = (const void**) xnn_reallocate_memory(deconvolution_op->indirection_buffer, indirection_buffer_size);
702*4bdc9457SAndroid Build Coastguard Worker if (indirection_buffer == NULL) {
703*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
704*4bdc9457SAndroid Build Coastguard Worker "failed to allocate %zu bytes for %s operator indirection buffer",
705*4bdc9457SAndroid Build Coastguard Worker indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));
706*4bdc9457SAndroid Build Coastguard Worker return xnn_status_out_of_memory;
707*4bdc9457SAndroid Build Coastguard Worker }
708*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->indirection_buffer = indirection_buffer;
709*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->last_input = input;
710*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->last_input_height = input_height;
711*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->last_input_width = input_width;
712*4bdc9457SAndroid Build Coastguard Worker
713*4bdc9457SAndroid Build Coastguard Worker xnn_indirection_init_deconv2d(deconvolution_op, mr, log2_input_element_size);
714*4bdc9457SAndroid Build Coastguard Worker }
715*4bdc9457SAndroid Build Coastguard Worker
716*4bdc9457SAndroid Build Coastguard Worker const size_t group_input_channels = deconvolution_op->group_input_channels;
717*4bdc9457SAndroid Build Coastguard Worker const size_t group_output_channels = deconvolution_op->group_output_channels;
718*4bdc9457SAndroid Build Coastguard Worker const uint32_t nr = deconvolution_op->ukernel.igemm.nr;
719*4bdc9457SAndroid Build Coastguard Worker
720*4bdc9457SAndroid Build Coastguard Worker const size_t w_stride = bias_element_size +
721*4bdc9457SAndroid Build Coastguard Worker (round_up_po2(group_input_channels, deconvolution_op->ukernel.igemm.kr * deconvolution_op->ukernel.igemm.sr) * kernel_size << log2_filter_element_size);
722*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->context.igemm = (struct igemm_context) {
723*4bdc9457SAndroid Build Coastguard Worker .ks = kernel_size,
724*4bdc9457SAndroid Build Coastguard Worker .ks_scaled = kernel_size * mr * sizeof(void*),
725*4bdc9457SAndroid Build Coastguard Worker .kc = group_input_channels << log2_input_element_size,
726*4bdc9457SAndroid Build Coastguard Worker .w_stride = w_stride,
727*4bdc9457SAndroid Build Coastguard Worker .indirect_a = deconvolution_op->indirection_buffer,
728*4bdc9457SAndroid Build Coastguard Worker .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input),
729*4bdc9457SAndroid Build Coastguard Worker .zero = deconvolution_op->zero_buffer,
730*4bdc9457SAndroid Build Coastguard Worker .packed_w = packed_weights(deconvolution_op),
731*4bdc9457SAndroid Build Coastguard Worker .c = deconvolution_op->output,
732*4bdc9457SAndroid Build Coastguard Worker .cm_stride = deconvolution_op->output_pixel_stride << log2_output_element_size,
733*4bdc9457SAndroid Build Coastguard Worker .cn_stride = nr << log2_output_element_size,
734*4bdc9457SAndroid Build Coastguard Worker .ga_stride = group_input_channels << log2_input_element_size,
735*4bdc9457SAndroid Build Coastguard Worker .gw_stride = w_stride * round_up(group_output_channels, nr),
736*4bdc9457SAndroid Build Coastguard Worker .gc_stride = group_output_channels << log2_output_element_size,
737*4bdc9457SAndroid Build Coastguard Worker .ba_stride = input_height * input_width * deconvolution_op->input_pixel_stride << log2_input_element_size,
738*4bdc9457SAndroid Build Coastguard Worker .bc_stride = output_size * deconvolution_op->output_pixel_stride << log2_output_element_size,
739*4bdc9457SAndroid Build Coastguard Worker .log2_csize = log2_output_element_size,
740*4bdc9457SAndroid Build Coastguard Worker .ukernel = igemm_ukernel,
741*4bdc9457SAndroid Build Coastguard Worker };
742*4bdc9457SAndroid Build Coastguard Worker memcpy(&deconvolution_op->context.igemm.params, params, params_size);
743*4bdc9457SAndroid Build Coastguard Worker
744*4bdc9457SAndroid Build Coastguard Worker #if XNN_TEST_MODE
745*4bdc9457SAndroid Build Coastguard Worker const size_t nc = nr;
746*4bdc9457SAndroid Build Coastguard Worker #else
747*4bdc9457SAndroid Build Coastguard Worker size_t nc = group_output_channels;
748*4bdc9457SAndroid Build Coastguard Worker if (num_threads > 1) {
749*4bdc9457SAndroid Build Coastguard Worker const size_t num_other_tiles = groups * batch_size * divide_round_up(output_size, mr);
750*4bdc9457SAndroid Build Coastguard Worker const size_t target_tiles_per_thread = 5;
751*4bdc9457SAndroid Build Coastguard Worker const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
752*4bdc9457SAndroid Build Coastguard Worker if (max_nc < nc) {
753*4bdc9457SAndroid Build Coastguard Worker nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
754*4bdc9457SAndroid Build Coastguard Worker }
755*4bdc9457SAndroid Build Coastguard Worker }
756*4bdc9457SAndroid Build Coastguard Worker #endif
757*4bdc9457SAndroid Build Coastguard Worker if (groups == 1) {
758*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
759*4bdc9457SAndroid Build Coastguard Worker if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
760*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
761*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
762*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_batch_hmp_igemm;
763*4bdc9457SAndroid Build Coastguard Worker } else {
764*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_2d_tile_2d_with_uarch;
765*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_igemm;
766*4bdc9457SAndroid Build Coastguard Worker }
767*4bdc9457SAndroid Build Coastguard Worker } else {
768*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
769*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
770*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_batch_igemm;
771*4bdc9457SAndroid Build Coastguard Worker } else {
772*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
773*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_igemm;
774*4bdc9457SAndroid Build Coastguard Worker }
775*4bdc9457SAndroid Build Coastguard Worker }
776*4bdc9457SAndroid Build Coastguard Worker #else
777*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
778*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
779*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_batch_igemm;
780*4bdc9457SAndroid Build Coastguard Worker } else {
781*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
782*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_igemm;
783*4bdc9457SAndroid Build Coastguard Worker }
784*4bdc9457SAndroid Build Coastguard Worker #endif
785*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
786*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[0] = batch_size;
787*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[1] = output_size;
788*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[2] = group_output_channels;
789*4bdc9457SAndroid Build Coastguard Worker } else {
790*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[0] = output_size;
791*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[1] = group_output_channels;
792*4bdc9457SAndroid Build Coastguard Worker }
793*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.tile[0] = mr;
794*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.tile[1] = nc;
795*4bdc9457SAndroid Build Coastguard Worker } else {
796*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
797*4bdc9457SAndroid Build Coastguard Worker if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
798*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
799*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_4d_tile_2d_with_uarch;
800*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_4d_tile_2d_with_id = (pthreadpool_task_4d_tile_2d_with_id_t) xnn_compute_hmp_grouped_batch_igemm;
801*4bdc9457SAndroid Build Coastguard Worker } else {
802*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
803*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_igemm;
804*4bdc9457SAndroid Build Coastguard Worker }
805*4bdc9457SAndroid Build Coastguard Worker } else {
806*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
807*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
808*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm;
809*4bdc9457SAndroid Build Coastguard Worker } else {
810*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
811*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_igemm;
812*4bdc9457SAndroid Build Coastguard Worker }
813*4bdc9457SAndroid Build Coastguard Worker }
814*4bdc9457SAndroid Build Coastguard Worker #else
815*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
816*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
817*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm;
818*4bdc9457SAndroid Build Coastguard Worker } else {
819*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
820*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_igemm;
821*4bdc9457SAndroid Build Coastguard Worker }
822*4bdc9457SAndroid Build Coastguard Worker #endif
823*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
824*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[0] = batch_size;
825*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[1] = groups;
826*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[2] = output_size;
827*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[3] = group_output_channels;
828*4bdc9457SAndroid Build Coastguard Worker } else {
829*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[0] = groups;
830*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[1] = output_size;
831*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.range[2] = group_output_channels;
832*4bdc9457SAndroid Build Coastguard Worker }
833*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.tile[0] = mr;
834*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->compute.tile[1] = nc;
835*4bdc9457SAndroid Build Coastguard Worker }
836*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->state = xnn_run_state_ready;
837*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
838*4bdc9457SAndroid Build Coastguard Worker }
839*4bdc9457SAndroid Build Coastguard Worker
// Prepares a deconvolution operator whose microkernel type is
// xnn_ukernel_type_subconv2d for execution.
//
// The output is computed as stride_height * stride_width interleaved
// sub-convolutions, one per (offset_y, offset_x) output phase, each described by
// one entry of deconvolution_op->subconvolution_buffer. This function:
//   1. re-bases every sub-convolution's weights pointer if the weights cache
//      relocated the packed weights,
//   2. refreshes the per-sub-convolution output parameters when the
//      input/output dimensions or the output pointer changed,
//   3. on the IGEMM path (use_gemm == false), reallocates and rebuilds the
//      indirection buffer when the dimensions changed,
//   4. fills the subgemm/subconv compute context and the parallelization
//      parameters, then marks the operator ready.
//
// Parameters:
//   deconvolution_op             operator to set up; ukernel.type must be subconv2d.
//   batch_size                   number of images in the batch.
//   input_height, input_width    input spatial dimensions.
//   input                        pointer to the input tensor.
//   output_height, output_width  output spatial dimensions.
//   output                       pointer to the output tensor.
//   log2_input_element_size,
//   log2_filter_element_size,
//   log2_output_element_size     log2 of the size in bytes of one element; used
//                                to turn element counts into byte strides.
//   bias_element_size            size in bytes of one bias element.
//   params, params_size          microkernel parameters, copied into the context.
//   num_threads                  thread count; only used to choose the channel tile.
//   use_gemm                     true: GEMM microkernels reading the input directly;
//                                false: IGEMM microkernels via an indirection buffer.
//
// Returns xnn_status_success, or xnn_status_out_of_memory if the indirection
// buffer cannot be (re)allocated.
static enum xnn_status setup_subconv2d_path(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const void* input,
  size_t output_height,
  size_t output_width,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads,
  bool use_gemm)
{
  assert(deconvolution_op->ukernel.type == xnn_ukernel_type_subconv2d);

  const size_t kernel_height = deconvolution_op->kernel_height;
  const size_t kernel_width = deconvolution_op->kernel_width;
  const size_t kernel_size = kernel_height * kernel_width;
  const size_t stride_height = deconvolution_op->stride_height;
  const size_t stride_width = deconvolution_op->stride_width;
  // One sub-convolution per (offset_y, offset_x) output phase.
  const size_t num_subconvolutions = stride_height * stride_width;
  // Upper bound on the rows/columns any single sub-convolution produces.
  const size_t output_height_positions = divide_round_up(output_height, stride_height);
  const size_t output_width_positions = divide_round_up(output_width, stride_width);

  const size_t groups = deconvolution_op->groups;
  const size_t output_size = output_height * output_width;
  uint32_t mr = deconvolution_op->ukernel.igemm.mr;
  const uint32_t nr = deconvolution_op->ukernel.igemm.nr;
#if XNN_ENABLE_GEMM_M_SPECIALIZATION
  // The M dimension of each microkernel call runs along output_width_positions
  // (see compute.range/tile below), so let the heuristic pick MR from it.
  mr = xnn_get_heuristic_mr_igemm(
    output_width_positions, mr, nr, deconvolution_op->ukernel.igemm.igemm_cases);
#endif

  const size_t input_pixel_stride = deconvolution_op->input_pixel_stride << log2_input_element_size;
  const size_t output_pixel_stride = deconvolution_op->output_pixel_stride << log2_output_element_size;

  const bool any_size_change =
    input_height != deconvolution_op->last_input_height ||
    input_width != deconvolution_op->last_input_width ||
    output_height != deconvolution_op->last_output_height ||
    output_width != deconvolution_op->last_output_width;

  if (deconvolution_op->weights_cache != NULL) {
    void* packed_weights_ptr = packed_weights(deconvolution_op);
    struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
    if (packed_weights_ptr != subconvolution_params->weights) {
      // The weights cache moved the packed weights: shift every sub-convolution's
      // weights pointer by the same displacement. The displacement stays unsigned
      // so a backwards move wraps around (well-defined) rather than going through
      // an out-of-range unsigned-to-ptrdiff_t conversion.
      const uintptr_t displacement =
        (uintptr_t) packed_weights_ptr - (uintptr_t) subconvolution_params->weights;
      for (size_t i = 0; i < num_subconvolutions; i++) {
        subconvolution_params[i].weights =
          (void*) ((uintptr_t) subconvolution_params[i].weights + displacement);
      }
    }
  }

  if (any_size_change || output != deconvolution_op->last_output) {
    // Refresh the sub-convolution parameters that depend on the output
    // dimensions, the output pointer, or MR (which is derived from output_width).
    struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
    const size_t modulo_padding_top = deconvolution_op->padding_top % stride_height;
    const size_t modulo_padding_left = deconvolution_op->padding_left % stride_width;
    for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
      // First output row written by this row phase.
      const size_t output_y_start = subtract_modulo(offset_y, modulo_padding_top, stride_height);
      // When the output is shorter than this phase's starting row, the phase has
      // no rows; clamp to 0 instead of letting the unsigned subtraction underflow
      // into a huge slice extent.
      const size_t slice_height = output_height > output_y_start ?
        divide_round_up(output_height - output_y_start, stride_height) : 0;
      for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
        // First output column written by this column phase.
        const size_t output_x_start = subtract_modulo(offset_x, modulo_padding_left, stride_width);
        subconvolution_params->scaled_kernel_size = mr * subconvolution_params->indirection_x_stride;
        // Same clamp as slice_height, for outputs narrower than output_x_start.
        subconvolution_params->slice_width = output_width > output_x_start ?
          divide_round_up(output_width - output_x_start, stride_width) : 0;
        subconvolution_params->slice_height = slice_height;
        subconvolution_params->output =
          (void*) ((uintptr_t) output + ((output_y_start * output_width + output_x_start) * output_pixel_stride));
        ++subconvolution_params;
      }
    }
    deconvolution_op->last_output = output;
  }

  if (any_size_change) {
    if (!use_gemm) {
      // The IGEMM path reads the input through an indirection buffer whose size
      // depends on the output dimensions and MR.
      const size_t indirection_buffer_size = sizeof(void*) *
        kernel_size * output_height * stride_width * round_up(divide_round_up(output_width, stride_width), mr);

      const void** indirection_buffer =
        (const void**) xnn_reallocate_memory(deconvolution_op->indirection_buffer, indirection_buffer_size);
      if (indirection_buffer == NULL) {
        xnn_log_error(
          "failed to allocate %zu bytes for %s operator indirection buffer",
          indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));
        return xnn_status_out_of_memory;
      }
      deconvolution_op->indirection_buffer = indirection_buffer;
      // The indirection buffer is built against this input pointer; later setups
      // with a different pointer are handled via context.subconv.a_offset.
      deconvolution_op->last_input = input;

      xnn_indirection_init_subconv2d(deconvolution_op, mr, log2_input_element_size);
    }
    deconvolution_op->last_input_height = input_height;
    deconvolution_op->last_input_width = input_width;
    deconvolution_op->last_output_height = output_height;
    deconvolution_op->last_output_width = output_width;
  }

  const size_t group_input_channels = deconvolution_op->group_input_channels;
  const size_t group_output_channels = deconvolution_op->group_output_channels;
  const uint32_t kr = deconvolution_op->ukernel.igemm.kr;
  const uint32_t sr = deconvolution_op->ukernel.igemm.sr;
  // Byte stride between consecutive NR-column blocks of packed weights: one bias
  // per sub-convolution plus the padded filter for all kernel taps.
  const size_t w_stride = stride_height * stride_width * bias_element_size +
    (round_up_po2(group_input_channels, kr * sr) * kernel_size << log2_filter_element_size);
  if (use_gemm) {
    deconvolution_op->context.subgemm = (struct subgemm_context) {
      .subconvolution_params = deconvolution_op->subconvolution_buffer,
      .kc = group_input_channels << log2_input_element_size,
      .a = input,
      .ax_stride = input_pixel_stride,
      .ay_stride = input_width * input_pixel_stride,
      .cx_stride = stride_width * output_pixel_stride,
      .cy_stride = stride_height * output_width * output_pixel_stride,
      .cn_stride = nr << log2_output_element_size,
      .ga_stride = group_input_channels << log2_input_element_size,
      .gw_stride = w_stride * round_up(group_output_channels, nr),
      .gc_stride = group_output_channels << log2_output_element_size,
      .ba_stride = input_height * input_width * input_pixel_stride,
      .bc_stride = output_size * output_pixel_stride,
      .log2_csize = log2_output_element_size,
      .ukernel = deconvolution_op->ukernel.igemm.gemm_cases[mr - 1],
    };
    memcpy(&deconvolution_op->context.subgemm.params, params, params_size);
  } else {
    deconvolution_op->context.subconv = (struct subconv_context) {
      .subconvolution_params = deconvolution_op->subconvolution_buffer,
      .kc = group_input_channels << log2_input_element_size,
      .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input),
      .zero = deconvolution_op->zero_buffer,
      .cx_stride = stride_width * output_pixel_stride,
      .cy_stride = stride_height * output_width * output_pixel_stride,
      .cn_stride = nr << log2_output_element_size,
      .ga_stride = group_input_channels << log2_input_element_size,
      .gw_stride = w_stride * round_up(group_output_channels, nr),
      .gc_stride = group_output_channels << log2_output_element_size,
      .ba_stride = input_height * input_width * input_pixel_stride,
      .bc_stride = output_size * output_pixel_stride,
      .log2_csize = log2_output_element_size,
      .ukernel = deconvolution_op->ukernel.igemm.igemm_cases[mr - 1],
    };
    memcpy(&deconvolution_op->context.subconv.params, params, params_size);
  }

#if XNN_TEST_MODE
  const size_t nc = nr;
#else
  size_t nc = group_output_channels;
  if (num_threads > 1) {
    // Number of tiles along every parallelized dimension except output channels.
    // batch_size is included because the batch is parallelized (range[0]).
    const size_t num_other_tiles = batch_size * groups * num_subconvolutions *
      output_height_positions * divide_round_up(output_width_positions, mr);
    const size_t target_tiles_per_thread = 5;
    // Largest channel tile that still yields roughly
    // num_threads * target_tiles_per_thread tiles in total.
    const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      // Round that bound up to whole NR-wide microkernel column blocks.
      nc = min(nc, divide_round_up(max_nc, nr) * nr);
    }
  }
#endif

  if (groups == 1) {
    deconvolution_op->compute.type = xnn_parallelization_type_5d_tile_2d;
    deconvolution_op->compute.task_5d_tile_2d = use_gemm ?
      (pthreadpool_task_5d_tile_2d_t) xnn_compute_subgemm2d : (pthreadpool_task_5d_tile_2d_t) xnn_compute_subconv2d;
    deconvolution_op->compute.range[0] = batch_size;
    deconvolution_op->compute.range[1] = num_subconvolutions;
    deconvolution_op->compute.range[2] = output_height_positions;
    deconvolution_op->compute.range[3] = output_width_positions;
    deconvolution_op->compute.range[4] = group_output_channels;
  } else {
    deconvolution_op->compute.type = xnn_parallelization_type_6d_tile_2d;
    deconvolution_op->compute.task_6d_tile_2d = use_gemm ?
      (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_subgemm2d : (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_subconv2d;
    deconvolution_op->compute.range[0] = batch_size;
    deconvolution_op->compute.range[1] = groups;
    deconvolution_op->compute.range[2] = num_subconvolutions;
    deconvolution_op->compute.range[3] = output_height_positions;
    deconvolution_op->compute.range[4] = output_width_positions;
    deconvolution_op->compute.range[5] = group_output_channels;
  }
  // Both layouts tile the output-width positions by MR and the channels by nc.
  deconvolution_op->compute.tile[0] = mr;
  deconvolution_op->compute.tile[1] = nc;

  deconvolution_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
1033*4bdc9457SAndroid Build Coastguard Worker
setup_deconvolution2d_nhwc(xnn_operator_t deconvolution_op,size_t batch_size,size_t input_height,size_t input_width,uint32_t adjustment_height,uint32_t adjustment_width,const void * input,void * output,uint32_t log2_input_element_size,uint32_t log2_filter_element_size,uint32_t bias_element_size,uint32_t log2_output_element_size,const void * params,size_t params_size,size_t num_threads)1034*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status setup_deconvolution2d_nhwc(
1035*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op,
1036*4bdc9457SAndroid Build Coastguard Worker size_t batch_size,
1037*4bdc9457SAndroid Build Coastguard Worker size_t input_height,
1038*4bdc9457SAndroid Build Coastguard Worker size_t input_width,
1039*4bdc9457SAndroid Build Coastguard Worker uint32_t adjustment_height,
1040*4bdc9457SAndroid Build Coastguard Worker uint32_t adjustment_width,
1041*4bdc9457SAndroid Build Coastguard Worker const void* input,
1042*4bdc9457SAndroid Build Coastguard Worker void* output,
1043*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_input_element_size,
1044*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_filter_element_size,
1045*4bdc9457SAndroid Build Coastguard Worker uint32_t bias_element_size,
1046*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_output_element_size,
1047*4bdc9457SAndroid Build Coastguard Worker const void* params,
1048*4bdc9457SAndroid Build Coastguard Worker size_t params_size,
1049*4bdc9457SAndroid Build Coastguard Worker size_t num_threads)
1050*4bdc9457SAndroid Build Coastguard Worker {
1051*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->state = xnn_run_state_invalid;
1052*4bdc9457SAndroid Build Coastguard Worker
1053*4bdc9457SAndroid Build Coastguard Worker if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
1054*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
1055*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(deconvolution_op->type));
1056*4bdc9457SAndroid Build Coastguard Worker return xnn_status_uninitialized;
1057*4bdc9457SAndroid Build Coastguard Worker }
1058*4bdc9457SAndroid Build Coastguard Worker
1059*4bdc9457SAndroid Build Coastguard Worker if (input_width == 0 || input_height == 0) {
1060*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
1061*4bdc9457SAndroid Build Coastguard Worker "failed to setup %s operator with %zux%zu input: input dimensions must be non-zero",
1062*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(deconvolution_op->type), input_width, input_height);
1063*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_parameter;
1064*4bdc9457SAndroid Build Coastguard Worker }
1065*4bdc9457SAndroid Build Coastguard Worker
1066*4bdc9457SAndroid Build Coastguard Worker if (adjustment_height >= deconvolution_op->stride_height) {
1067*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
1068*4bdc9457SAndroid Build Coastguard Worker "failed to setup %s operator with %" PRIu32 " height adjustment: "
1069*4bdc9457SAndroid Build Coastguard Worker "height adjustment must be smaller than height stride (%" PRIu32 ")",
1070*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(deconvolution_op->type), adjustment_height, deconvolution_op->stride_height);
1071*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_parameter;
1072*4bdc9457SAndroid Build Coastguard Worker }
1073*4bdc9457SAndroid Build Coastguard Worker
1074*4bdc9457SAndroid Build Coastguard Worker if (adjustment_width >= deconvolution_op->stride_width) {
1075*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
1076*4bdc9457SAndroid Build Coastguard Worker "failed to setup %s operator with %" PRIu32 " width adjustment: "
1077*4bdc9457SAndroid Build Coastguard Worker "width adjustment must be smaller than width stride (%" PRIu32 ")",
1078*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(deconvolution_op->type), adjustment_width, deconvolution_op->stride_width);
1079*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_parameter;
1080*4bdc9457SAndroid Build Coastguard Worker }
1081*4bdc9457SAndroid Build Coastguard Worker
1082*4bdc9457SAndroid Build Coastguard Worker if (batch_size == 0) {
1083*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->state = xnn_run_state_skip;
1084*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1085*4bdc9457SAndroid Build Coastguard Worker }
1086*4bdc9457SAndroid Build Coastguard Worker
1087*4bdc9457SAndroid Build Coastguard Worker if (deconvolution_op->weights_cache != NULL && !xnn_weights_cache_is_finalized(deconvolution_op->weights_cache)) {
1088*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to setup %s operator: weights cache is not finalized",
1089*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(deconvolution_op->type));
1090*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_state;
1091*4bdc9457SAndroid Build Coastguard Worker }
1092*4bdc9457SAndroid Build Coastguard Worker
1093*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->batch_size = batch_size;
1094*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->input_height = input_height;
1095*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->input_width = input_width;
1096*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->input = input;
1097*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->output = output;
1098*4bdc9457SAndroid Build Coastguard Worker
1099*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->output_height = xnn_compute_deconvolution_output_dimension(
1100*4bdc9457SAndroid Build Coastguard Worker input_height, deconvolution_op->padding_top + deconvolution_op->padding_bottom,
1101*4bdc9457SAndroid Build Coastguard Worker adjustment_height, deconvolution_op->kernel_height, deconvolution_op->dilation_height, deconvolution_op->stride_height);
1102*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->output_width = deconvolution_op->output_width = xnn_compute_deconvolution_output_dimension(
1103*4bdc9457SAndroid Build Coastguard Worker input_width, deconvolution_op->padding_left + deconvolution_op->padding_right,
1104*4bdc9457SAndroid Build Coastguard Worker adjustment_width, deconvolution_op->kernel_width, deconvolution_op->dilation_width, deconvolution_op->stride_width);
1105*4bdc9457SAndroid Build Coastguard Worker
1106*4bdc9457SAndroid Build Coastguard Worker switch (deconvolution_op->ukernel.type) {
1107*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_igemm:
1108*4bdc9457SAndroid Build Coastguard Worker return setup_conv_path(
1109*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1110*4bdc9457SAndroid Build Coastguard Worker batch_size,
1111*4bdc9457SAndroid Build Coastguard Worker input_height, input_width, input,
1112*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->output_height, deconvolution_op->output_width, output,
1113*4bdc9457SAndroid Build Coastguard Worker log2_input_element_size, log2_filter_element_size, bias_element_size, log2_output_element_size,
1114*4bdc9457SAndroid Build Coastguard Worker params, params_size, num_threads);
1115*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_subconv2d:
1116*4bdc9457SAndroid Build Coastguard Worker {
1117*4bdc9457SAndroid Build Coastguard Worker const size_t mr = deconvolution_op->ukernel.igemm.mr;
1118*4bdc9457SAndroid Build Coastguard Worker const bool no_padding = (deconvolution_op->padding_top | deconvolution_op->padding_right | deconvolution_op->padding_bottom | deconvolution_op->padding_left) == 0;
1119*4bdc9457SAndroid Build Coastguard Worker const bool no_adjustment = (adjustment_height | adjustment_width) == 0;
1120*4bdc9457SAndroid Build Coastguard Worker const bool use_gemm = no_padding && no_adjustment &&
1121*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->kernel_height == deconvolution_op->stride_height &&
1122*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->kernel_width == deconvolution_op->stride_width &&
1123*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->ukernel.igemm.gemm_cases[mr - 1].function[XNN_UARCH_DEFAULT] != NULL;
1124*4bdc9457SAndroid Build Coastguard Worker return setup_subconv2d_path(
1125*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1126*4bdc9457SAndroid Build Coastguard Worker batch_size,
1127*4bdc9457SAndroid Build Coastguard Worker input_height, input_width, input,
1128*4bdc9457SAndroid Build Coastguard Worker deconvolution_op->output_height, deconvolution_op->output_width, output,
1129*4bdc9457SAndroid Build Coastguard Worker log2_input_element_size, log2_filter_element_size, bias_element_size, log2_output_element_size,
1130*4bdc9457SAndroid Build Coastguard Worker params, params_size, num_threads, use_gemm);
1131*4bdc9457SAndroid Build Coastguard Worker }
1132*4bdc9457SAndroid Build Coastguard Worker default:
1133*4bdc9457SAndroid Build Coastguard Worker XNN_UNREACHABLE;
1134*4bdc9457SAndroid Build Coastguard Worker }
1135*4bdc9457SAndroid Build Coastguard Worker }
1136*4bdc9457SAndroid Build Coastguard Worker
// Prepares a signed 8-bit quantized (QS8) NHWC deconvolution operator to run on
// the given input/output buffers.
//
// Returns xnn_status_invalid_parameter (after logging) when `deconvolution_op`
// was not created as a QS8 deconvolution. Otherwise every argument is forwarded
// to the shared setup_deconvolution2d_nhwc routine together with the QS8 element
// sizes and the operator's qs8_conv_minmax parameters, and its status is
// returned. `threadpool` is only queried here for its thread count.
enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  uint32_t adjustment_height,
  uint32_t adjustment_width,
  const int8_t* input,
  int8_t* output,
  pthreadpool_t threadpool)
{
  if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_qs8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8),
      xnn_operator_type_to_string(deconvolution_op->type));
    return xnn_status_invalid_parameter;
  }

  // Input, filter and output elements are all int8_t: 1 byte, so log2(size) == 0.
  const uint32_t log2_element_size = 0;
  // Bias values are stored as int32_t.
  const size_t bias_element_size = sizeof(int32_t);
  const void* minmax_params = &deconvolution_op->params.qs8_conv_minmax;
  const size_t minmax_params_size = sizeof(deconvolution_op->params.qs8_conv_minmax);
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_deconvolution2d_nhwc(
    deconvolution_op,
    batch_size, input_height, input_width,
    adjustment_height, adjustment_width,
    input, output,
    log2_element_size /* input */,
    log2_element_size /* filter */,
    bias_element_size,
    log2_element_size /* output */,
    minmax_params, minmax_params_size,
    num_threads);
}
1167*4bdc9457SAndroid Build Coastguard Worker
// Prepares an unsigned 8-bit quantized (QU8) NHWC deconvolution operator to run
// on the given input/output buffers.
//
// Returns xnn_status_invalid_parameter (after logging) when `deconvolution_op`
// was not created as a QU8 deconvolution. Otherwise every argument is forwarded
// to the shared setup_deconvolution2d_nhwc routine together with the QU8 element
// sizes and the operator's qu8_conv_minmax parameters, and its status is
// returned. `threadpool` is only queried here for its thread count.
enum xnn_status xnn_setup_deconvolution2d_nhwc_qu8(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  uint32_t adjustment_height,
  uint32_t adjustment_width,
  const uint8_t* input,
  uint8_t* output,
  pthreadpool_t threadpool)
{
  if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_qu8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8),
      xnn_operator_type_to_string(deconvolution_op->type));
    return xnn_status_invalid_parameter;
  }

  // Input, filter and output elements are all uint8_t: 1 byte, so log2(size) == 0.
  const uint32_t log2_element_size = 0;
  // Bias values are stored as int32_t.
  const size_t bias_element_size = sizeof(int32_t);
  const void* minmax_params = &deconvolution_op->params.qu8_conv_minmax;
  const size_t minmax_params_size = sizeof(deconvolution_op->params.qu8_conv_minmax);
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_deconvolution2d_nhwc(
    deconvolution_op,
    batch_size, input_height, input_width,
    adjustment_height, adjustment_width,
    input, output,
    log2_element_size /* input */,
    log2_element_size /* filter */,
    bias_element_size,
    log2_element_size /* output */,
    minmax_params, minmax_params_size,
    num_threads);
}
1198*4bdc9457SAndroid Build Coastguard Worker
// Prepares a half-precision (F16) NHWC deconvolution operator to run on the
// given input/output buffers.
//
// Returns xnn_status_invalid_parameter (after logging) when `deconvolution_op`
// was not created as an F16 deconvolution. Otherwise every argument is forwarded
// to the shared setup_deconvolution2d_nhwc routine together with the F16 element
// sizes and the operator's f16_minmax parameters, and its status is returned.
// `threadpool` is only queried here for its thread count.
enum xnn_status xnn_setup_deconvolution2d_nhwc_f16(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  uint32_t adjustment_height,
  uint32_t adjustment_width,
  const void* input,
  void* output,
  pthreadpool_t threadpool)
{
  if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_f16) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16),
      xnn_operator_type_to_string(deconvolution_op->type));
    return xnn_status_invalid_parameter;
  }

  // Input, filter and output elements are 16-bit values (stored as uint16_t):
  // 2 bytes, so log2(size) == 1.
  const uint32_t log2_element_size = 1;
  // Bias values are 16-bit as well.
  const size_t bias_element_size = sizeof(uint16_t);
  const void* minmax_params = &deconvolution_op->params.f16_minmax;
  const size_t minmax_params_size = sizeof(deconvolution_op->params.f16_minmax);
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_deconvolution2d_nhwc(
    deconvolution_op,
    batch_size, input_height, input_width,
    adjustment_height, adjustment_width,
    input, output,
    log2_element_size /* input */,
    log2_element_size /* filter */,
    bias_element_size,
    log2_element_size /* output */,
    minmax_params, minmax_params_size,
    num_threads);
}
1229*4bdc9457SAndroid Build Coastguard Worker
// Prepares a single-precision (F32) NHWC deconvolution operator to run on the
// given input/output buffers.
//
// Returns xnn_status_invalid_parameter (after logging) when `deconvolution_op`
// was not created as an F32 deconvolution. Otherwise every argument is forwarded
// to the shared setup_deconvolution2d_nhwc routine together with the F32 element
// sizes and the operator's f32_minmax parameters, and its status is returned.
// `threadpool` is only queried here for its thread count.
enum xnn_status xnn_setup_deconvolution2d_nhwc_f32(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  uint32_t adjustment_height,
  uint32_t adjustment_width,
  const float* input,
  float* output,
  pthreadpool_t threadpool)
{
  if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32),
      xnn_operator_type_to_string(deconvolution_op->type));
    return xnn_status_invalid_parameter;
  }

  // Input, filter and output elements are all float: 4 bytes, so log2(size) == 2.
  const uint32_t log2_element_size = 2;
  // Bias values are float as well.
  const size_t bias_element_size = sizeof(float);
  const void* minmax_params = &deconvolution_op->params.f32_minmax;
  const size_t minmax_params_size = sizeof(deconvolution_op->params.f32_minmax);
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_deconvolution2d_nhwc(
    deconvolution_op,
    batch_size, input_height, input_width,
    adjustment_height, adjustment_width,
    input, output,
    log2_element_size /* input */,
    log2_element_size /* filter */,
    bias_element_size,
    log2_element_size /* output */,
    minmax_params, minmax_params_size,
    num_threads);
}
1260