xref: /aosp_15_r20/external/XNNPACK/src/operators/channel-shuffle-nc.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14 
15 #include <xnnpack.h>
16 #include <xnnpack/allocator.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/log.h>
19 #include <xnnpack/params.h>
20 
21 
// Shared implementation behind xnn_create_channel_shuffle_nc_x8 and
// xnn_create_channel_shuffle_nc_x32.
//
// Validates the shuffle geometry and allocates a zero-initialized operator
// descriptor recording it. The new operator is left in xnn_run_state_invalid.
//
//   groups                 - number of channel groups; must be at least 2.
//   group_channels         - channels per group; must be non-zero.
//   input_stride           - input pixel stride in elements; must be at least
//                            groups * group_channels.
//   output_stride          - output pixel stride in elements; must be at least
//                            groups * group_channels.
//   flags                  - stored on the operator as-is.
//   operator_type          - operator type tag, used in log messages and stored
//                            on the operator (the setup entry points check it).
//   channel_shuffle_op_out - receives the operator on success; not written on
//                            failure.
//
// Returns xnn_status_success, xnn_status_uninitialized if XNNPACK has not been
// initialized, xnn_status_invalid_parameter for invalid geometry (including a
// channel count that does not fit in size_t), or xnn_status_out_of_memory if
// the descriptor allocation fails.
static enum xnn_status create_channel_shuffle_nc(
  size_t groups,
  size_t group_channels,
  size_t input_stride,
  size_t output_stride,
  uint32_t flags,
  enum xnn_operator_type operator_type,
  xnn_operator_t* channel_shuffle_op_out)
{
  xnn_operator_t channel_shuffle_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (groups <= 1) {
    xnn_log_error(
      "failed to create %s operator with %zu groups: at least two groups required",
      xnn_operator_type_to_string(operator_type), groups);
    goto error;
  }

  if (group_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu group channels: number of group channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_channels);
    goto error;
  }

  // groups >= 2 here, so the division cannot trap. Without this guard the
  // product below could wrap and let undersized strides pass validation.
  if (group_channels > SIZE_MAX / groups) {
    xnn_log_error(
      "failed to create %s operator with %zu groups of %zu channels: total number of channels overflows size_t",
      xnn_operator_type_to_string(operator_type), groups, group_channels);
    goto error;
  }

  const size_t channels = groups * group_channels;
  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, groups, group_channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, groups, group_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  channel_shuffle_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (channel_shuffle_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  channel_shuffle_op->groups = groups;
  channel_shuffle_op->group_channels = group_channels;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output_pixel_stride = output_stride;

  channel_shuffle_op->type = operator_type;
  channel_shuffle_op->flags = flags;

  // Not runnable until a setup call succeeds.
  channel_shuffle_op->state = xnn_run_state_invalid;

  *channel_shuffle_op_out = channel_shuffle_op;
  return xnn_status_success;

error:
  // channel_shuffle_op is still NULL unless the allocation succeeded.
  xnn_delete_operator(channel_shuffle_op);
  return status;
}
100 
101 
// Creates a channel shuffle operator over 8-bit elements.
//
// Forwards every argument to create_channel_shuffle_nc, tagging the operator
// as xnn_operator_type_channel_shuffle_nc_x8, and returns its status unchanged.
enum xnn_status xnn_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_channel_shuffle_nc_x8;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags, op_type, channel_shuffle_op_out);
}
119 
// Creates a channel shuffle operator over 32-bit elements.
//
// Forwards every argument to create_channel_shuffle_nc, tagging the operator
// as xnn_operator_type_channel_shuffle_nc_x32, and returns its status unchanged.
enum xnn_status xnn_create_channel_shuffle_nc_x32(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_channel_shuffle_nc_x32;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags, op_type, channel_shuffle_op_out);
}
137 
// Shared implementation behind xnn_setup_channel_shuffle_nc_x8 and
// xnn_setup_channel_shuffle_nc_x32.
//
// Binds the batch size and the input/output pointers to the operator, fills
// in its channel_shuffle context and picks the zip microkernel to run:
//   - 2, 3 or 4 groups use the fixed-arity kernels zip->x2 / x3 / x4 with
//     xnn_compute_channel_shuffle_fixed;
//   - more groups use zip->xm with xnn_compute_channel_shuffle_variable.
// Strides and the per-group width are stored in the context shifted left by
// log2_element_size (element counts converted to bytes).
//
// Returns xnn_status_uninitialized if XNNPACK has not been initialized.
// A zero batch_size returns xnn_status_success with the state set to
// xnn_run_state_skip; otherwise the state becomes xnn_run_state_ready.
static enum xnn_status setup_channel_shuffle_nc(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_element_size,
    const struct zip_parameters zip[restrict XNN_MIN_ELEMENTS(1)])
{
  // Invalidate up front so a failed setup never leaves a stale ready state.
  channel_shuffle_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(channel_shuffle_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    channel_shuffle_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  channel_shuffle_op->batch_size = batch_size;
  channel_shuffle_op->input = input;
  channel_shuffle_op->output = output;

  const size_t num_groups = channel_shuffle_op->groups;
  struct channel_shuffle_context* ctx = &channel_shuffle_op->context.channel_shuffle;
  *ctx = (struct channel_shuffle_context) {
    .x = input,
    .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size,
    .y = output,
    .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size,
    .n = channel_shuffle_op->group_channels << log2_element_size,
    .m = num_groups,
  };

  // One task per batch element.
  channel_shuffle_op->compute.type = xnn_parallelization_type_1d;
  channel_shuffle_op->compute.range[0] = batch_size;

  switch (num_groups) {
    case 0:
    case 1:
      // The create path rejects fewer than two groups.
      XNN_UNREACHABLE;
      break;
    case 2:
      ctx->fixed_ukernel = zip->x2;
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      break;
    case 3:
      ctx->fixed_ukernel = zip->x3;
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      break;
    case 4:
      ctx->fixed_ukernel = zip->x4;
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
      break;
    default:
      ctx->variable_ukernel = zip->xm;
      channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable;
      break;
  }

  channel_shuffle_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
199 
// Sets up a channel shuffle operator over 8-bit elements.
//
// Fails with xnn_status_invalid_parameter if the operator was not created as
// xnn_operator_type_channel_shuffle_nc_x8. Otherwise it delegates to
// setup_channel_shuffle_nc with a log2 element size of 0 and the x8 zip kernels
// from xnn_params. `threadpool` is not used by this function.
enum xnn_status xnn_setup_channel_shuffle_nc_x8(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;

  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x8;
  if (channel_shuffle_op->type == expected_type) {
    // log2(sizeof(uint8_t)) == 0
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output, 0, &xnn_params.x8.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
222 
// Sets up a channel shuffle operator over 32-bit elements.
//
// Fails with xnn_status_invalid_parameter if the operator was not created as
// xnn_operator_type_channel_shuffle_nc_x32. Otherwise it delegates to
// setup_channel_shuffle_nc with a log2 element size of 2 and the x32 zip
// kernels from xnn_params. `threadpool` is not used by this function.
enum xnn_status xnn_setup_channel_shuffle_nc_x32(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;

  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x32;
  if (channel_shuffle_op->type == expected_type) {
    // log2(sizeof(uint32_t)) == 2
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output, 2, &xnn_params.x32.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
245