xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/channel-shuffle.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14 
15 #include <pytorch_qnnpack.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/operator.h>
18 #include <qnnpack/params.h>
19 
/*
 * Create a channel shuffle operator for 8-bit ("x8") data.
 *
 * groups:              number of channel groups; must be at least 2.
 * group_channels:      number of channels in each group; must be non-zero.
 * flags:               not read by this function.
 * channel_shuffle_out: receives the newly allocated operator on success;
 *                      not written on failure.
 *
 * Returns:
 *   pytorch_qnnp_status_success           operator created and stored.
 *   pytorch_qnnp_status_uninitialized     QNNPACK has not been initialized.
 *   pytorch_qnnp_status_invalid_parameter groups < 2 or group_channels == 0.
 *   pytorch_qnnp_status_out_of_memory     the operator struct could not be
 *                                         allocated.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    uint32_t flags,
    pytorch_qnnp_operator_t* channel_shuffle_out) {
  pytorch_qnnp_operator_t op = NULL;
  /* Status reported if we bail out; updated as each stage is entered. */
  enum pytorch_qnnp_status result = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_channel_shuffle_nc_x8 failed because QNNPACK is not properly initialized");
    goto fail;
  }

  /* Argument validation stage: a shuffle needs >= 2 non-empty groups. */
  result = pytorch_qnnp_status_invalid_parameter;
  if (groups < 2) {
    pytorch_qnnp_log_error(
        "failed to create channel shuffle operator with %zu groups: "
        "at least two groups required",
        groups);
    goto fail;
  }
  if (!group_channels) {
    pytorch_qnnp_log_error(
        "failed to create channel shuffle operator with %zu group channels: "
        "number of group channels must be non-zero",
        group_channels);
    goto fail;
  }

  /* Allocation stage: calloc leaves every other operator field zeroed. */
  result = pytorch_qnnp_status_out_of_memory;
  op = calloc(1, sizeof(*op));
  if (op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(*op));
    goto fail;
  }

  op->ukernel_type = pytorch_qnnp_ukernel_type_channel_shuffle;
  op->format = pytorch_qnnp_format_quint8;
  op->groups = groups;
  op->group_channels = group_channels;

  *channel_shuffle_out = op;
  return pytorch_qnnp_status_success;

fail:
  /*
   * Reached with op still NULL on every path except a successful allocation;
   * the deleter is called regardless (presumably it tolerates NULL).
   */
  pytorch_qnnp_delete_operator(op);
  return result;
}
75 
/*
 * Bind a batch of input/output rows to a channel shuffle operator created by
 * pytorch_qnnp_create_channel_shuffle_nc_x8. No data is moved here; the
 * pointers and strides are only recorded on the operator.
 *
 * channel_shuffle_op: operator to configure (must be non-NULL).
 * batch_size:         number of rows to process. 0 is accepted and only
 *                     records the zero batch size; input/output/strides are
 *                     then left unchanged and are not validated.
 * input, output:      base pointers of the source and destination rows.
 * input_stride,
 * output_stride:      distance between consecutive rows, in uint8 elements.
 *                     Each must be >= groups * group_channels (the number of
 *                     channels per row, presumably); smaller values would make
 *                     consecutive rows overlap.
 *
 * Returns:
 *   pytorch_qnnp_status_success           configuration recorded.
 *   pytorch_qnnp_status_uninitialized     QNNPACK has not been initialized.
 *   pytorch_qnnp_status_invalid_parameter a stride is smaller than the
 *                                         channel count.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_channel_shuffle_nc_x8(
    pytorch_qnnp_operator_t channel_shuffle_op,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_channel_shuffle_nc_x8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size == 0) {
    channel_shuffle_op->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  /*
   * Casts keep the arithmetic and the %zu format specifiers well-defined
   * regardless of how the operator fields are declared.
   */
  const size_t groups = (size_t) channel_shuffle_op->groups;
  const size_t group_channels = (size_t) channel_shuffle_op->group_channels;
  const size_t channels = groups * group_channels;

  if (input_stride < channels) {
    pytorch_qnnp_log_error(
        "failed to setup channel shuffle operator with input stride of %zu: "
        "stride must be at least as large as the number of channels (%zux%zu)",
        input_stride,
        groups,
        group_channels);
    return pytorch_qnnp_status_invalid_parameter;
  }

  if (output_stride < channels) {
    pytorch_qnnp_log_error(
        "failed to setup channel shuffle operator with output stride of %zu: "
        "stride must be at least as large as the number of channels (%zux%zu)",
        output_stride,
        groups,
        group_channels);
    return pytorch_qnnp_status_invalid_parameter;
  }

  channel_shuffle_op->batch_size = batch_size;
  channel_shuffle_op->input = input;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output = output;
  channel_shuffle_op->output_pixel_stride = output_stride;

  return pytorch_qnnp_status_success;
}
102