// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/common.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>


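// Creates the XNNPACK operator object for a Deconvolution 2D subgraph node.
// The filter and the optional bias must be static values (non-NULL data);
// the concrete operator variant is selected by the node's compute type.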
static enum xnn_status create_deconvolution_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs >= 2);
  assert(node->num_inputs <= 3);
  const bool use_bias = node->num_inputs >= 3;

  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);
  const uint32_t filter_id = node->inputs[1];
  assert(filter_id != XNN_INVALID_VALUE_ID);
  assert(filter_id < num_values);

  const void* bias_data = NULL;
  if (use_bias) {
    const uint32_t bias_id = node->inputs[2];
    assert(bias_id != XNN_INVALID_VALUE_ID);
    assert(bias_id < num_values);

    bias_data = values[bias_id].data;
    assert(bias_data != NULL);
  }

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  const void* filter_data = values[filter_id].data;
  assert(filter_data != NULL);
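
  // Instantiate the operator variant that matches the node's compute type.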
  enum xnn_status status = xnn_status_uninitialized;
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16:
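      // The static filter and bias are kept in fp32; XNN_FLAG_FP32_STATIC_WEIGHTS
      // asks the operator to convert them to fp16 internally.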
      status = xnn_create_deconvolution2d_nhwc_f16(
          node->params.deconvolution_2d.padding_top,
          node->params.deconvolution_2d.padding_right,
          node->params.deconvolution_2d.padding_bottom,
          node->params.deconvolution_2d.padding_left,
          node->params.deconvolution_2d.kernel_height,
          node->params.deconvolution_2d.kernel_width,
          node->params.deconvolution_2d.upsampling_height,
          node->params.deconvolution_2d.upsampling_width,
          node->params.deconvolution_2d.dilation_height,
          node->params.deconvolution_2d.dilation_width,
          node->params.deconvolution_2d.groups,
          node->params.deconvolution_2d.group_input_channels,
          node->params.deconvolution_2d.group_output_channels,
          node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */,
          node->params.deconvolution_2d.group_output_channels * node->params.deconvolution_2d.groups /* output_pixel_stride */,
          filter_data,
          bias_data,
          node->activation.output_min,
          node->activation.output_max,
          node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
          caches,
          &opdata->operator_objects[0]);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32:
      status = xnn_create_deconvolution2d_nhwc_f32(
          node->params.deconvolution_2d.padding_top,
          node->params.deconvolution_2d.padding_right,
          node->params.deconvolution_2d.padding_bottom,
          node->params.deconvolution_2d.padding_left,
          node->params.deconvolution_2d.kernel_height,
          node->params.deconvolution_2d.kernel_width,
          node->params.deconvolution_2d.upsampling_height,
          node->params.deconvolution_2d.upsampling_width,
          node->params.deconvolution_2d.dilation_height,
          node->params.deconvolution_2d.dilation_width,
          node->params.deconvolution_2d.groups,
          node->params.deconvolution_2d.group_input_channels,
          node->params.deconvolution_2d.group_output_channels,
          node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */,
          node->params.deconvolution_2d.group_output_channels * node->params.deconvolution_2d.groups /* output_pixel_stride */,
          filter_data,
          bias_data,
          node->activation.output_min,
          node->activation.output_max,
          node->flags,
          caches,
          &opdata->operator_objects[0]);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_compute_type_qs8:
    {
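      // Convert the float activation bounds into the quantized output domain
      // before passing them to the operator.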
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_deconvolution2d_nhwc_qs8(
          node->params.deconvolution_2d.padding_top,
          node->params.deconvolution_2d.padding_right,
          node->params.deconvolution_2d.padding_bottom,
          node->params.deconvolution_2d.padding_left,
          node->params.deconvolution_2d.kernel_height,
          node->params.deconvolution_2d.kernel_width,
          node->params.deconvolution_2d.upsampling_height,
          node->params.deconvolution_2d.upsampling_width,
          node->params.deconvolution_2d.dilation_height,
          node->params.deconvolution_2d.dilation_width,
          node->params.deconvolution_2d.groups,
          node->params.deconvolution_2d.group_input_channels,
          node->params.deconvolution_2d.group_output_channels,
          node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */,
          node->params.deconvolution_2d.group_output_channels * node->params.deconvolution_2d.groups /* output_pixel_stride */,
          (int8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          values[filter_id].quantization.scale,
          filter_data,
          bias_data,
          output_zero_point,
          output_scale,
          output_min,
          output_max,
          node->flags,
          caches,
          &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_compute_type_qu8:
    {
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_deconvolution2d_nhwc_qu8(
          node->params.deconvolution_2d.padding_top,
          node->params.deconvolution_2d.padding_right,
          node->params.deconvolution_2d.padding_bottom,
          node->params.deconvolution_2d.padding_left,
          node->params.deconvolution_2d.kernel_height,
          node->params.deconvolution_2d.kernel_width,
          node->params.deconvolution_2d.upsampling_height,
          node->params.deconvolution_2d.upsampling_width,
          node->params.deconvolution_2d.dilation_height,
          node->params.deconvolution_2d.dilation_width,
          node->params.deconvolution_2d.groups,
          node->params.deconvolution_2d.group_input_channels,
          node->params.deconvolution_2d.group_output_channels,
          node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */,
          node->params.deconvolution_2d.group_output_channels * node->params.deconvolution_2d.groups /* output_pixel_stride */,
          (uint8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          (uint8_t) values[filter_id].quantization.zero_point,
          values[filter_id].quantization.scale,
          filter_data,
          bias_data,
          output_zero_point,
          output_scale,
          output_min,
          output_max,
          node->flags,
          caches,
          &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
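  // On success, cache the input geometry and tensor IDs that the setup
  // function needs at runtime.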
  if (status == xnn_status_success) {
    opdata->batch_size = values[input_id].shape.dim[0];
    opdata->input_height = values[input_id].shape.dim[1];
    opdata->input_width = values[input_id].shape.dim[2];
    opdata->adjustment_height = node->params.deconvolution_2d.adjustment_height;
    opdata->adjustment_width = node->params.deconvolution_2d.adjustment_width;
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

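// Binds the runtime input and output blobs to the previously created
// deconvolution operator and sets it up for execution on the threadpool.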
static enum xnn_status setup_deconvolution_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_deconvolution_nhwc_f16:
      return xnn_setup_deconvolution2d_nhwc_f16(
          opdata->operator_objects[0],
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          opdata->adjustment_height,
          opdata->adjustment_width,
          input_data,
          output_data,
          threadpool);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_deconvolution_nhwc_f32:
      return xnn_setup_deconvolution2d_nhwc_f32(
          opdata->operator_objects[0],
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          opdata->adjustment_height,
          opdata->adjustment_width,
          input_data,
          output_data,
          threadpool);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_deconvolution_nhwc_qs8:
      return xnn_setup_deconvolution2d_nhwc_qs8(
          opdata->operator_objects[0],
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          opdata->adjustment_height,
          opdata->adjustment_width,
          input_data,
          output_data,
          threadpool);
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_deconvolution_nhwc_qu8:
      return xnn_setup_deconvolution2d_nhwc_qu8(
          opdata->operator_objects[0],
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          opdata->adjustment_height,
          opdata->adjustment_width,
          input_data,
          output_data,
          threadpool);
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

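// Maps the (input, filter, bias, output) datatype combination to a compute
// type; returns xnn_compute_type_invalid if the combination is unsupported.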
static inline enum xnn_compute_type validate_datatypes_with_bias(
  enum xnn_datatype input_datatype,
  enum xnn_datatype filter_datatype,
  enum xnn_datatype bias_datatype,
  enum xnn_datatype output_datatype)
{
  switch (filter_datatype) {
    case xnn_datatype_fp32:
      if (input_datatype == xnn_datatype_fp32 &&
          bias_datatype == xnn_datatype_fp32 &&
          output_datatype == xnn_datatype_fp32)
      {
        return xnn_compute_type_fp32;
      }
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      if (input_datatype == xnn_datatype_qint8 &&
          bias_datatype == xnn_datatype_qint32 &&
          output_datatype == xnn_datatype_qint8)
      {
        return xnn_compute_type_qs8;
      }
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      if (input_datatype == xnn_datatype_quint8 &&
          bias_datatype == xnn_datatype_qint32 &&
          output_datatype == xnn_datatype_quint8)
      {
        return xnn_compute_type_qu8;
      }
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  return xnn_compute_type_invalid;
}

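// Same mapping as above for nodes defined without a bias tensor.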
static inline enum xnn_compute_type validate_datatypes_without_bias(
  enum xnn_datatype input_datatype,
  enum xnn_datatype filter_datatype,
  enum xnn_datatype output_datatype)
{
  switch (filter_datatype) {
    case xnn_datatype_fp32:
      if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) {
        return xnn_compute_type_fp32;
      }
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
        return xnn_compute_type_qs8;
      }
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      if (input_datatype == xnn_datatype_quint8 && output_datatype == xnn_datatype_quint8) {
        return xnn_compute_type_qu8;
      }
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  return xnn_compute_type_invalid;
}

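// Public entry point: validates the deconvolution parameters and the
// referenced subgraph values, then records a Deconvolution 2D node.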
enum xnn_status xnn_define_deconvolution_2d(
  xnn_subgraph_t subgraph,
  uint32_t padding_top,
  uint32_t padding_right,
  uint32_t padding_bottom,
  uint32_t padding_left,
  uint32_t adjustment_height,
  uint32_t adjustment_width,
  uint32_t kernel_height,
  uint32_t kernel_width,
  uint32_t upsampling_height,
  uint32_t upsampling_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  uint32_t groups,
  size_t group_input_channels,
  size_t group_output_channels,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t filter_id,
  uint32_t bias_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_deconvolution_2d)) != xnn_status_success) {
    return status;
  }

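  // Reject structurally invalid hyperparameters before inspecting the
  // subgraph values.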
  if (kernel_width == 0 || kernel_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), kernel_width, kernel_height);
    return xnn_status_invalid_parameter;
  }

  if (upsampling_width == 0 || upsampling_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " upsampling: upsampling dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), upsampling_width, upsampling_height);
    return xnn_status_invalid_parameter;
  }

  if (dilation_width == 0 || dilation_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), dilation_width, dilation_height);
    return xnn_status_invalid_parameter;
  }

  if (groups == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 " groups: number of groups must be non-zero",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), groups);
    return xnn_status_invalid_parameter;
  }

  if (group_input_channels == 0) {
    xnn_log_error(
      "failed to define %s operator with %zu input channels per group: number of channels must be non-zero",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), group_input_channels);
    return xnn_status_invalid_parameter;
  }

  if (group_output_channels == 0) {
    xnn_log_error(
      "failed to define %s operator with %zu output channels per group: number of channels must be non-zero",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), group_output_channels);
    return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_deconvolution_2d, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_deconvolution_2d, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_deconvolution_2d, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (filter_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), filter_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* filter_value = &subgraph->values[filter_id];
  if (filter_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), filter_id, filter_value->type);
    return xnn_status_invalid_parameter;
  }

  if (filter_value->data == NULL) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": non-static Value",
      xnn_node_type_to_string(xnn_node_type_deconvolution_2d), filter_id);
    return xnn_status_invalid_parameter;
  }

  switch (filter_value->datatype) {
    case xnn_datatype_fp32:
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      if (filter_value->quantization.zero_point != 0) {
        xnn_log_error(
          "failed to define %s operator with filter ID #%" PRIu32 ": unsupported quantization zero point %" PRId32 " for datatype %s",
          xnn_node_type_to_string(xnn_node_type_deconvolution_2d), filter_id,
          filter_value->quantization.zero_point, xnn_datatype_to_string(filter_value->datatype));
        return xnn_status_invalid_parameter;
      }
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), filter_id,
        xnn_datatype_to_string(filter_value->datatype), filter_value->datatype);
      return xnn_status_invalid_parameter;
  }

  const struct xnn_value* bias_value = NULL;

  if (bias_id != XNN_INVALID_VALUE_ID) {
    if (bias_id >= subgraph->num_values) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": invalid Value ID",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), bias_id);
      return xnn_status_invalid_parameter;
    }

    bias_value = &subgraph->values[bias_id];
    if (bias_value->type != xnn_value_type_dense_tensor) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), bias_id, bias_value->type);
      return xnn_status_invalid_parameter;
    }

    if (bias_value->data == NULL) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": non-static Value",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), bias_id);
      return xnn_status_invalid_parameter;
    }

    switch (bias_value->datatype) {
      case xnn_datatype_fp32:
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
      case xnn_datatype_qint32:
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
        break;
      default:
        xnn_log_error(
          "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
          xnn_node_type_to_string(xnn_node_type_deconvolution_2d), bias_id,
          xnn_datatype_to_string(bias_value->datatype), bias_value->datatype);
        return xnn_status_invalid_parameter;
    }
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_deconvolution_2d, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_deconvolution_2d, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  if (bias_value != NULL) {
    compute_type = validate_datatypes_with_bias(
      input_value->datatype, filter_value->datatype, bias_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", bias ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), bias (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), input_id, filter_id, bias_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(bias_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  } else {
    compute_type = validate_datatypes_without_bias(
      input_value->datatype, filter_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_deconvolution_2d), input_id, filter_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  }

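  // All checks passed: allocate a new node and record the deconvolution
  // parameters, activation bounds, and tensor IDs.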
  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_deconvolution_2d;
  node->compute_type = compute_type;
  node->params.deconvolution_2d.padding_top = padding_top;
  node->params.deconvolution_2d.padding_right = padding_right;
  node->params.deconvolution_2d.padding_bottom = padding_bottom;
  node->params.deconvolution_2d.padding_left = padding_left;
  node->params.deconvolution_2d.kernel_height = kernel_height;
  node->params.deconvolution_2d.kernel_width = kernel_width;
  node->params.deconvolution_2d.upsampling_height = upsampling_height;
  node->params.deconvolution_2d.upsampling_width = upsampling_width;
  node->params.deconvolution_2d.dilation_height = dilation_height;
  node->params.deconvolution_2d.dilation_width = dilation_width;
  node->params.deconvolution_2d.adjustment_height = adjustment_height;
  node->params.deconvolution_2d.adjustment_width = adjustment_width;
  node->params.deconvolution_2d.groups = groups;
  node->params.deconvolution_2d.group_input_channels = group_input_channels;
  node->params.deconvolution_2d.group_output_channels = group_output_channels;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 2 + (size_t) (bias_value != NULL);
  node->inputs[0] = input_id;
  node->inputs[1] = filter_id;
  node->inputs[2] = bias_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_deconvolution_operator;
  node->setup = setup_deconvolution_operator;

  return xnn_status_success;
}