1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <xnnpack.h>
16 #include <xnnpack/allocator.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/log.h>
19 #include <xnnpack/params.h>
20
21
// Validates a channel shuffle configuration and allocates the operator
// descriptor. Shared implementation behind the x8 and x32 create functions.
//
// Parameters:
//   groups                 - number of channel groups; must be at least 2.
//   group_channels         - number of channels in each group; must be non-zero.
//   input_stride           - input stride in elements; must be at least
//                            groups * group_channels.
//   output_stride          - output stride in elements; must be at least
//                            groups * group_channels.
//   flags                  - stored verbatim in the operator descriptor.
//   operator_type          - recorded in the descriptor and used in log messages.
//   channel_shuffle_op_out - receives the new operator on success; it is not
//                            written on failure.
//
// Returns:
//   xnn_status_success           - the operator was created.
//   xnn_status_uninitialized     - XNNPACK has not been initialized.
//   xnn_status_invalid_parameter - a parameter failed validation, including
//                                  the case where groups * group_channels
//                                  does not fit in size_t.
//   xnn_status_out_of_memory     - the descriptor could not be allocated.
static enum xnn_status create_channel_shuffle_nc(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    xnn_operator_t* channel_shuffle_op_out)
{
  xnn_operator_t channel_shuffle_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  // Every check from here until the allocation reports an invalid parameter.
  status = xnn_status_invalid_parameter;

  if (groups <= 1) {
    xnn_log_error(
      "failed to create %s operator with %zu groups: at least two groups required",
      xnn_operator_type_to_string(operator_type), groups);
    goto error;
  }

  if (group_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu group channels: number of group channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_channels);
    goto error;
  }

  // groups > 1 at this point, so the division is well-defined. Without this
  // guard the product below could wrap around to a small value and let
  // undersized strides slip through the stride checks.
  if (group_channels > SIZE_MAX / groups) {
    xnn_log_error(
      "failed to create %s operator with %zu groups and %zu group channels: "
      "total number of channels overflows size_t",
      xnn_operator_type_to_string(operator_type), groups, group_channels);
    goto error;
  }

  const size_t channels = groups * group_channels;
  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, groups, group_channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, groups, group_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  channel_shuffle_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (channel_shuffle_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  channel_shuffle_op->groups = groups;
  channel_shuffle_op->group_channels = group_channels;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output_pixel_stride = output_stride;

  channel_shuffle_op->type = operator_type;
  channel_shuffle_op->flags = flags;

  // The operator cannot run until a setup call binds buffers to it.
  channel_shuffle_op->state = xnn_run_state_invalid;

  *channel_shuffle_op_out = channel_shuffle_op;
  return xnn_status_success;

error:
  // channel_shuffle_op is still NULL on every path that jumps here; the call
  // is kept so any future failure after allocation releases the descriptor.
  xnn_delete_operator(channel_shuffle_op);
  return status;
}
100
101
// Creates the x8 variant of the channel shuffle operator. All arguments are
// forwarded unchanged to create_channel_shuffle_nc together with the x8
// operator type, and its status is returned as-is.
enum xnn_status xnn_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type operator_type = xnn_operator_type_channel_shuffle_nc_x8;
  const enum xnn_status status = create_channel_shuffle_nc(
    groups, group_channels,
    input_stride, output_stride,
    flags, operator_type,
    channel_shuffle_op_out);
  return status;
}
119
// Creates the x32 variant of the channel shuffle operator. All arguments are
// forwarded unchanged to create_channel_shuffle_nc together with the x32
// operator type, and its status is returned as-is.
enum xnn_status xnn_create_channel_shuffle_nc_x32(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type operator_type = xnn_operator_type_channel_shuffle_nc_x32;
  const enum xnn_status status = create_channel_shuffle_nc(
    groups, group_channels,
    input_stride, output_stride,
    flags, operator_type,
    channel_shuffle_op_out);
  return status;
}
137
// Binds a batch size and input/output buffers to a channel shuffle operator
// and selects the zip micro-kernel that matches its group count.
//
// log2_element_size is the base-2 logarithm of the element size in bytes; it
// scales the element strides and the per-group channel count by shifting.
// zip supplies the micro-kernels: x2, x3 and x4 for exactly 2, 3 or 4 groups,
// and xm for any larger group count.
//
// Returns xnn_status_uninitialized if XNNPACK is not initialized, otherwise
// xnn_status_success. A zero batch_size marks the operator as
// xnn_run_state_skip and returns before the buffers are recorded.
static enum xnn_status setup_channel_shuffle_nc(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_element_size,
    const struct zip_parameters zip[restrict XNN_MIN_ELEMENTS(1)])
{
  // The operator stays invalid until setup has run to completion.
  channel_shuffle_op->state = xnn_run_state_invalid;

  if (!(xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK)) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(channel_shuffle_op->type));
    return xnn_status_uninitialized;
  }

  // Nothing to compute for an empty batch.
  if (batch_size == 0) {
    channel_shuffle_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  channel_shuffle_op->batch_size = batch_size;
  channel_shuffle_op->input = input;
  channel_shuffle_op->output = output;

  const size_t num_groups = channel_shuffle_op->groups;

  // Build the context locally, then commit it to the operator in one step.
  struct channel_shuffle_context shuffle = {
    .x = input,
    .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size,
    .y = output,
    .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size,
    .n = channel_shuffle_op->group_channels << log2_element_size,
    .m = num_groups,
  };

  // Group counts 2..4 use a fixed micro-kernel; larger counts use the
  // variable one.
  pthreadpool_task_1d_t task = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
  switch (num_groups) {
    case 2:
      shuffle.fixed_ukernel = zip->x2;
      break;
    case 3:
      shuffle.fixed_ukernel = zip->x3;
      break;
    case 4:
      shuffle.fixed_ukernel = zip->x4;
      break;
    default:
      task = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable;
      shuffle.variable_ukernel = zip->xm;
      break;
    case 0:
    case 1:
      // Operator creation rejects fewer than two groups.
      XNN_UNREACHABLE;
  }

  channel_shuffle_op->context.channel_shuffle = shuffle;
  channel_shuffle_op->compute.type = xnn_parallelization_type_1d;
  channel_shuffle_op->compute.range[0] = batch_size;
  channel_shuffle_op->compute.task_1d = task;

  channel_shuffle_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
199
// Sets up an x8 channel shuffle operator for the given batch and buffers.
//
// Returns xnn_status_invalid_parameter if the operator was not created as the
// x8 variant; otherwise returns the result of setup_channel_shuffle_nc, called
// with a 1-byte element size and the x8 zip micro-kernels.
// threadpool is accepted but not used by this function.
enum xnn_status xnn_setup_channel_shuffle_nc_x8(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x8;

  if (channel_shuffle_op->type == expected_type) {
    const uint32_t log2_element_size = 0;  // log2(sizeof(uint8_t))
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output,
      log2_element_size, &xnn_params.x8.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
222
// Sets up an x32 channel shuffle operator for the given batch and buffers.
//
// Returns xnn_status_invalid_parameter if the operator was not created as the
// x32 variant; otherwise returns the result of setup_channel_shuffle_nc,
// called with a 4-byte element size and the x32 zip micro-kernels.
// threadpool is accepted but not used by this function.
enum xnn_status xnn_setup_channel_shuffle_nc_x32(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x32;

  if (channel_shuffle_op->type == expected_type) {
    const uint32_t log2_element_size = 2;  // log2(sizeof(uint32_t))
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output,
      log2_element_size, &xnn_params.x32.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
245