1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
11*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
12*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/params.h>
13*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph-validation.h>
15*4bdc9457SAndroid Build Coastguard Worker
create_concatenate_operator_helper(const struct xnn_node * node,size_t channels,size_t input_stride,size_t output_stride,struct xnn_operator_data * opdata,size_t index)16*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status create_concatenate_operator_helper(
17*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node *node,
18*4bdc9457SAndroid Build Coastguard Worker size_t channels,
19*4bdc9457SAndroid Build Coastguard Worker size_t input_stride,
20*4bdc9457SAndroid Build Coastguard Worker size_t output_stride,
21*4bdc9457SAndroid Build Coastguard Worker struct xnn_operator_data *opdata,
22*4bdc9457SAndroid Build Coastguard Worker size_t index)
23*4bdc9457SAndroid Build Coastguard Worker {
24*4bdc9457SAndroid Build Coastguard Worker switch (node->compute_type) {
25*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_NO_F16_OPERATORS
26*4bdc9457SAndroid Build Coastguard Worker case xnn_compute_type_fp16: {
27*4bdc9457SAndroid Build Coastguard Worker return xnn_create_copy_nc_x16(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
28*4bdc9457SAndroid Build Coastguard Worker }
29*4bdc9457SAndroid Build Coastguard Worker #endif // !defined(XNN_NO_F16_OPERATORS)
30*4bdc9457SAndroid Build Coastguard Worker case xnn_compute_type_fp32: {
31*4bdc9457SAndroid Build Coastguard Worker return xnn_create_copy_nc_x32(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
32*4bdc9457SAndroid Build Coastguard Worker }
33*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_NO_QS8_OPERATORS
34*4bdc9457SAndroid Build Coastguard Worker case xnn_compute_type_qs8:
35*4bdc9457SAndroid Build Coastguard Worker #endif // !defined(XNN_NO_QS8_OPERATORS)
36*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_NO_QU8_OPERATORS
37*4bdc9457SAndroid Build Coastguard Worker case xnn_compute_type_qu8:
38*4bdc9457SAndroid Build Coastguard Worker #endif // !defined(XNN_NO_QU8_OPERATORS)
39*4bdc9457SAndroid Build Coastguard Worker #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
40*4bdc9457SAndroid Build Coastguard Worker {
41*4bdc9457SAndroid Build Coastguard Worker return xnn_create_copy_nc_x8(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
42*4bdc9457SAndroid Build Coastguard Worker }
43*4bdc9457SAndroid Build Coastguard Worker #endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
44*4bdc9457SAndroid Build Coastguard Worker default:
45*4bdc9457SAndroid Build Coastguard Worker XNN_UNREACHABLE;
46*4bdc9457SAndroid Build Coastguard Worker }
47*4bdc9457SAndroid Build Coastguard Worker }
48*4bdc9457SAndroid Build Coastguard Worker
create_concatenate2_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)49*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status create_concatenate2_operator(
50*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node* node,
51*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* values,
52*4bdc9457SAndroid Build Coastguard Worker size_t num_values,
53*4bdc9457SAndroid Build Coastguard Worker struct xnn_operator_data* opdata,
54*4bdc9457SAndroid Build Coastguard Worker const struct xnn_caches* caches)
55*4bdc9457SAndroid Build Coastguard Worker {
56*4bdc9457SAndroid Build Coastguard Worker assert(node->num_inputs == 2);
57*4bdc9457SAndroid Build Coastguard Worker const uint32_t input1_id = node->inputs[0];
58*4bdc9457SAndroid Build Coastguard Worker assert(input1_id != XNN_INVALID_VALUE_ID);
59*4bdc9457SAndroid Build Coastguard Worker assert(input1_id < num_values);
60*4bdc9457SAndroid Build Coastguard Worker const uint32_t input2_id = node->inputs[1];
61*4bdc9457SAndroid Build Coastguard Worker assert(input2_id != XNN_INVALID_VALUE_ID);
62*4bdc9457SAndroid Build Coastguard Worker assert(input2_id < num_values);
63*4bdc9457SAndroid Build Coastguard Worker
64*4bdc9457SAndroid Build Coastguard Worker assert(node->num_outputs == 1);
65*4bdc9457SAndroid Build Coastguard Worker const uint32_t output_id = node->outputs[0];
66*4bdc9457SAndroid Build Coastguard Worker assert(output_id != XNN_INVALID_VALUE_ID);
67*4bdc9457SAndroid Build Coastguard Worker assert(output_id < num_values);
68*4bdc9457SAndroid Build Coastguard Worker
69*4bdc9457SAndroid Build Coastguard Worker const size_t axis = node->params.concatenate.axis;
70*4bdc9457SAndroid Build Coastguard Worker size_t batch_size = 1, channels_1 = 1, channels_2 = 1;
71*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < axis; i++) {
72*4bdc9457SAndroid Build Coastguard Worker batch_size *= values[output_id].shape.dim[i];
73*4bdc9457SAndroid Build Coastguard Worker }
74*4bdc9457SAndroid Build Coastguard Worker
75*4bdc9457SAndroid Build Coastguard Worker for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
76*4bdc9457SAndroid Build Coastguard Worker channels_1 *= values[input1_id].shape.dim[i];
77*4bdc9457SAndroid Build Coastguard Worker channels_2 *= values[input2_id].shape.dim[i];
78*4bdc9457SAndroid Build Coastguard Worker }
79*4bdc9457SAndroid Build Coastguard Worker const size_t output_stride = channels_1 + channels_2;
80*4bdc9457SAndroid Build Coastguard Worker
81*4bdc9457SAndroid Build Coastguard Worker enum xnn_status status;
82*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
83*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
84*4bdc9457SAndroid Build Coastguard Worker return status;
85*4bdc9457SAndroid Build Coastguard Worker }
86*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
87*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
88*4bdc9457SAndroid Build Coastguard Worker return status;
89*4bdc9457SAndroid Build Coastguard Worker }
90*4bdc9457SAndroid Build Coastguard Worker
91*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[0] = input1_id;
92*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[1] = input2_id;
93*4bdc9457SAndroid Build Coastguard Worker opdata->outputs[0] = output_id;
94*4bdc9457SAndroid Build Coastguard Worker opdata->batch_size = batch_size;
95*4bdc9457SAndroid Build Coastguard Worker
96*4bdc9457SAndroid Build Coastguard Worker return status;
97*4bdc9457SAndroid Build Coastguard Worker }
98*4bdc9457SAndroid Build Coastguard Worker
create_concatenate3_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)99*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status create_concatenate3_operator(
100*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node* node,
101*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* values,
102*4bdc9457SAndroid Build Coastguard Worker size_t num_values,
103*4bdc9457SAndroid Build Coastguard Worker struct xnn_operator_data* opdata,
104*4bdc9457SAndroid Build Coastguard Worker const struct xnn_caches* caches)
105*4bdc9457SAndroid Build Coastguard Worker {
106*4bdc9457SAndroid Build Coastguard Worker assert(node->num_inputs == 3);
107*4bdc9457SAndroid Build Coastguard Worker const uint32_t input1_id = node->inputs[0];
108*4bdc9457SAndroid Build Coastguard Worker assert(input1_id != XNN_INVALID_VALUE_ID);
109*4bdc9457SAndroid Build Coastguard Worker assert(input1_id < num_values);
110*4bdc9457SAndroid Build Coastguard Worker const uint32_t input2_id = node->inputs[1];
111*4bdc9457SAndroid Build Coastguard Worker assert(input2_id != XNN_INVALID_VALUE_ID);
112*4bdc9457SAndroid Build Coastguard Worker assert(input2_id < num_values);
113*4bdc9457SAndroid Build Coastguard Worker const uint32_t input3_id = node->inputs[2];
114*4bdc9457SAndroid Build Coastguard Worker assert(input3_id != XNN_INVALID_VALUE_ID);
115*4bdc9457SAndroid Build Coastguard Worker assert(input3_id < num_values);
116*4bdc9457SAndroid Build Coastguard Worker
117*4bdc9457SAndroid Build Coastguard Worker assert(node->num_outputs == 1);
118*4bdc9457SAndroid Build Coastguard Worker const uint32_t output_id = node->outputs[0];
119*4bdc9457SAndroid Build Coastguard Worker assert(output_id != XNN_INVALID_VALUE_ID);
120*4bdc9457SAndroid Build Coastguard Worker assert(output_id < num_values);
121*4bdc9457SAndroid Build Coastguard Worker
122*4bdc9457SAndroid Build Coastguard Worker const size_t axis = node->params.concatenate.axis;
123*4bdc9457SAndroid Build Coastguard Worker size_t batch_size = 1, channels_1 = 1, channels_2 = 1, channels_3 = 1;
124*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < axis; i++) {
125*4bdc9457SAndroid Build Coastguard Worker batch_size *= values[output_id].shape.dim[i];
126*4bdc9457SAndroid Build Coastguard Worker }
127*4bdc9457SAndroid Build Coastguard Worker
128*4bdc9457SAndroid Build Coastguard Worker for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
129*4bdc9457SAndroid Build Coastguard Worker channels_1 *= values[input1_id].shape.dim[i];
130*4bdc9457SAndroid Build Coastguard Worker channels_2 *= values[input2_id].shape.dim[i];
131*4bdc9457SAndroid Build Coastguard Worker channels_3 *= values[input3_id].shape.dim[i];
132*4bdc9457SAndroid Build Coastguard Worker }
133*4bdc9457SAndroid Build Coastguard Worker const size_t output_stride = channels_1 + channels_2 + channels_3;
134*4bdc9457SAndroid Build Coastguard Worker
135*4bdc9457SAndroid Build Coastguard Worker enum xnn_status status;
136*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
137*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
138*4bdc9457SAndroid Build Coastguard Worker return status;
139*4bdc9457SAndroid Build Coastguard Worker }
140*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
141*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
142*4bdc9457SAndroid Build Coastguard Worker return status;
143*4bdc9457SAndroid Build Coastguard Worker }
144*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_3, channels_3, output_stride, opdata, 2);
145*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
146*4bdc9457SAndroid Build Coastguard Worker return status;
147*4bdc9457SAndroid Build Coastguard Worker }
148*4bdc9457SAndroid Build Coastguard Worker
149*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[0] = input1_id;
150*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[1] = input2_id;
151*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[2] = input3_id;
152*4bdc9457SAndroid Build Coastguard Worker opdata->outputs[0] = output_id;
153*4bdc9457SAndroid Build Coastguard Worker opdata->batch_size = batch_size;
154*4bdc9457SAndroid Build Coastguard Worker
155*4bdc9457SAndroid Build Coastguard Worker return status;
156*4bdc9457SAndroid Build Coastguard Worker }
157*4bdc9457SAndroid Build Coastguard Worker
create_concatenate4_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)158*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status create_concatenate4_operator(
159*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node* node,
160*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* values,
161*4bdc9457SAndroid Build Coastguard Worker size_t num_values,
162*4bdc9457SAndroid Build Coastguard Worker struct xnn_operator_data* opdata,
163*4bdc9457SAndroid Build Coastguard Worker const struct xnn_caches* caches)
164*4bdc9457SAndroid Build Coastguard Worker {
165*4bdc9457SAndroid Build Coastguard Worker assert(node->num_inputs == 4);
166*4bdc9457SAndroid Build Coastguard Worker const uint32_t input1_id = node->inputs[0];
167*4bdc9457SAndroid Build Coastguard Worker assert(input1_id != XNN_INVALID_VALUE_ID);
168*4bdc9457SAndroid Build Coastguard Worker assert(input1_id < num_values);
169*4bdc9457SAndroid Build Coastguard Worker const uint32_t input2_id = node->inputs[1];
170*4bdc9457SAndroid Build Coastguard Worker assert(input2_id != XNN_INVALID_VALUE_ID);
171*4bdc9457SAndroid Build Coastguard Worker assert(input2_id < num_values);
172*4bdc9457SAndroid Build Coastguard Worker const uint32_t input3_id = node->inputs[2];
173*4bdc9457SAndroid Build Coastguard Worker assert(input3_id != XNN_INVALID_VALUE_ID);
174*4bdc9457SAndroid Build Coastguard Worker assert(input3_id < num_values);
175*4bdc9457SAndroid Build Coastguard Worker const uint32_t input4_id = node->inputs[3];
176*4bdc9457SAndroid Build Coastguard Worker assert(input4_id != XNN_INVALID_VALUE_ID);
177*4bdc9457SAndroid Build Coastguard Worker assert(input4_id < num_values);
178*4bdc9457SAndroid Build Coastguard Worker
179*4bdc9457SAndroid Build Coastguard Worker assert(node->num_outputs == 1);
180*4bdc9457SAndroid Build Coastguard Worker const uint32_t output_id = node->outputs[0];
181*4bdc9457SAndroid Build Coastguard Worker assert(output_id != XNN_INVALID_VALUE_ID);
182*4bdc9457SAndroid Build Coastguard Worker assert(output_id < num_values);
183*4bdc9457SAndroid Build Coastguard Worker
184*4bdc9457SAndroid Build Coastguard Worker const size_t axis = node->params.concatenate.axis;
185*4bdc9457SAndroid Build Coastguard Worker size_t batch_size = 1, channels_1 = 1, channels_2 = 1, channels_3 = 1, channels_4 = 1;
186*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < axis; i++) {
187*4bdc9457SAndroid Build Coastguard Worker batch_size *= values[output_id].shape.dim[i];
188*4bdc9457SAndroid Build Coastguard Worker }
189*4bdc9457SAndroid Build Coastguard Worker
190*4bdc9457SAndroid Build Coastguard Worker for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
191*4bdc9457SAndroid Build Coastguard Worker channels_1 *= values[input1_id].shape.dim[i];
192*4bdc9457SAndroid Build Coastguard Worker channels_2 *= values[input2_id].shape.dim[i];
193*4bdc9457SAndroid Build Coastguard Worker channels_3 *= values[input3_id].shape.dim[i];
194*4bdc9457SAndroid Build Coastguard Worker channels_4 *= values[input4_id].shape.dim[i];
195*4bdc9457SAndroid Build Coastguard Worker }
196*4bdc9457SAndroid Build Coastguard Worker const size_t output_stride = channels_1 + channels_2 + channels_3 + channels_4;
197*4bdc9457SAndroid Build Coastguard Worker
198*4bdc9457SAndroid Build Coastguard Worker enum xnn_status status;
199*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
200*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
201*4bdc9457SAndroid Build Coastguard Worker return status;
202*4bdc9457SAndroid Build Coastguard Worker }
203*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
204*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
205*4bdc9457SAndroid Build Coastguard Worker return status;
206*4bdc9457SAndroid Build Coastguard Worker }
207*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_3, channels_3, output_stride, opdata, 2);
208*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
209*4bdc9457SAndroid Build Coastguard Worker return status;
210*4bdc9457SAndroid Build Coastguard Worker }
211*4bdc9457SAndroid Build Coastguard Worker status = create_concatenate_operator_helper(node, channels_4, channels_4, output_stride, opdata, 3);
212*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
213*4bdc9457SAndroid Build Coastguard Worker return status;
214*4bdc9457SAndroid Build Coastguard Worker }
215*4bdc9457SAndroid Build Coastguard Worker
216*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[0] = input1_id;
217*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[1] = input2_id;
218*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[2] = input3_id;
219*4bdc9457SAndroid Build Coastguard Worker opdata->inputs[3] = input4_id;
220*4bdc9457SAndroid Build Coastguard Worker opdata->outputs[0] = output_id;
221*4bdc9457SAndroid Build Coastguard Worker opdata->batch_size = batch_size;
222*4bdc9457SAndroid Build Coastguard Worker
223*4bdc9457SAndroid Build Coastguard Worker return status;
224*4bdc9457SAndroid Build Coastguard Worker }
225*4bdc9457SAndroid Build Coastguard Worker
// Sets up the copy operator stored at opdata->operator_objects[index].
//
// The operator reads from input_data and writes into output_data at an element
// offset equal to the summed `channels` of all operators stored before `index`.
// The element width used for that offset (16, 32 or 8 bits) follows the
// operator's type. opdata->batch_size and threadpool are forwarded to the
// matching xnn_setup_copy_nc_* call, whose status is returned. Any other
// operator type is marked unreachable.
static enum xnn_status setup_concatenate_operator_helper(
  const void* input_data,
  void* output_data,
  const struct xnn_operator_data *opdata,
  size_t index,
  pthreadpool_t threadpool)
{
  // Channels already written by the preceding operators determine where this
  // operator's slice begins in the output.
  size_t channel_offset = 0;
  for (size_t prev = 0; prev < index; prev++) {
    channel_offset += opdata->operator_objects[prev]->channels;
  }

  switch (opdata->operator_objects[index]->type) {
#if !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_copy_nc_x16:
      return xnn_setup_copy_nc_x16(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint16_t*) output_data + channel_offset,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_copy_nc_x32:
      return xnn_setup_copy_nc_x32(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint32_t*) output_data + channel_offset,
        threadpool);
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8:
      return xnn_setup_copy_nc_x8(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint8_t*) output_data + channel_offset,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
272*4bdc9457SAndroid Build Coastguard Worker
// Sets up both copy operators of a 2-input concatenate node.
//
// The input and output value ids recorded in opdata are used to look up the
// data pointers in `blobs`. The operators are then set up in input order, all
// against the same output pointer. The first non-success status is returned
// immediately; otherwise the result is xnn_status_success.
static enum xnn_status setup_concatenate2_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  enum { input_count = 2 };

  uint32_t input_ids[input_count];
  for (size_t k = 0; k < input_count; k++) {
    input_ids[k] = opdata->inputs[k];
    assert(input_ids[k] != XNN_INVALID_VALUE_ID);
    assert(input_ids[k] < num_blobs);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data[input_count];
  for (size_t k = 0; k < input_count; k++) {
    input_data[k] = blobs[input_ids[k]].data;
    assert(input_data[k] != NULL);
  }

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  for (size_t k = 0; k < input_count; k++) {
    const enum xnn_status status =
      setup_concatenate_operator_helper(input_data[k], output_data, opdata, k, threadpool);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return xnn_status_success;
}
311*4bdc9457SAndroid Build Coastguard Worker
// Sets up the three copy operators of a 3-input concatenate node.
//
// The input and output value ids recorded in opdata are used to look up the
// data pointers in `blobs`. The operators are then set up in input order, all
// against the same output pointer. The first non-success status is returned
// immediately; otherwise the result is xnn_status_success.
static enum xnn_status setup_concatenate3_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  enum { input_count = 3 };

  uint32_t input_ids[input_count];
  for (size_t k = 0; k < input_count; k++) {
    input_ids[k] = opdata->inputs[k];
    assert(input_ids[k] != XNN_INVALID_VALUE_ID);
    assert(input_ids[k] < num_blobs);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data[input_count];
  for (size_t k = 0; k < input_count; k++) {
    input_data[k] = blobs[input_ids[k]].data;
    assert(input_data[k] != NULL);
  }

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  for (size_t k = 0; k < input_count; k++) {
    const enum xnn_status status =
      setup_concatenate_operator_helper(input_data[k], output_data, opdata, k, threadpool);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return xnn_status_success;
}
362*4bdc9457SAndroid Build Coastguard Worker
// Sets up the four copy operators of a 4-input concatenate node.
//
// The input and output value ids recorded in opdata are used to look up the
// data pointers in `blobs`. The operators are then set up in input order, all
// against the same output pointer. The first non-success status is returned
// immediately; otherwise the result is xnn_status_success.
static enum xnn_status setup_concatenate4_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  enum { input_count = 4 };

  uint32_t input_ids[input_count];
  for (size_t k = 0; k < input_count; k++) {
    input_ids[k] = opdata->inputs[k];
    assert(input_ids[k] != XNN_INVALID_VALUE_ID);
    assert(input_ids[k] < num_blobs);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data[input_count];
  for (size_t k = 0; k < input_count; k++) {
    input_data[k] = blobs[input_ids[k]].data;
    assert(input_data[k] != NULL);
  }

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  for (size_t k = 0; k < input_count; k++) {
    const enum xnn_status status =
      setup_concatenate_operator_helper(input_data[k], output_data, opdata, k, threadpool);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return xnn_status_success;
}
425*4bdc9457SAndroid Build Coastguard Worker
// Validates one input of a Concatenate node against the node's output.
//
// Checks, in order:
//   1. the input ID is valid for the subgraph (xnn_subgraph_check_nth_input_node_id);
//   2. the input Value passes xnn_subgraph_check_input_type_dense;
//   3. input and output have the same number of dimensions;
//   4. every dimension other than `axis` has the same extent in input and output;
//   5. input and output datatypes match (xnn_subgraph_check_datatype_matches).
//
// `nth` is the ordinal of the input, used for the ID check and in log messages.
// `output_id` is used to index subgraph->values without a range check here.
//
// Returns xnn_status_success when all checks pass, xnn_status_invalid_parameter
// on a shape mismatch, or the status reported by the first failing helper.
enum xnn_status check_input_value(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input_id,
  uint32_t output_id,
  size_t nth,
  enum xnn_node_type node_type)
{
  enum xnn_status status = xnn_subgraph_check_nth_input_node_id(node_type, input_id, subgraph->num_values, nth);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input = subgraph->values + input_id;
  status = xnn_subgraph_check_input_type_dense(node_type, input_id, input);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output = subgraph->values + output_id;
  const size_t num_dims = input->shape.num_dims;
  if (num_dims != output->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with input %zu ID #%" PRIu32
      ": mismatch number of dimensions, input %zu has %zu, output has %zu",
      xnn_node_type_to_string(node_type), nth, input_id, nth, input->shape.num_dims,
      output->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  for (size_t dim = 0; dim < num_dims; dim++) {
    // The concatenation axis is allowed to differ between input and output.
    if (dim == axis) {
      continue;
    }
    if (input->shape.dim[dim] == output->shape.dim[dim]) {
      continue;
    }
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32
      ": mismatch dimension %zu, input %zu has %zu, output has %zu",
      xnn_node_type_to_string(node_type), input_id, dim, nth, input->shape.dim[dim], output->shape.dim[dim]);
    return xnn_status_invalid_parameter;
  }

  // Last check: its status is the function's result.
  return xnn_subgraph_check_datatype_matches(node_type, input_id, input, output_id, output);
}
473*4bdc9457SAndroid Build Coastguard Worker
474*4bdc9457SAndroid Build Coastguard Worker #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
check_input_compute_type(xnn_subgraph_t subgraph,uint32_t input_id,uint32_t output_id,const char * nth,enum xnn_node_type node_type)475*4bdc9457SAndroid Build Coastguard Worker enum xnn_status check_input_compute_type(
476*4bdc9457SAndroid Build Coastguard Worker xnn_subgraph_t subgraph,
477*4bdc9457SAndroid Build Coastguard Worker uint32_t input_id,
478*4bdc9457SAndroid Build Coastguard Worker uint32_t output_id,
479*4bdc9457SAndroid Build Coastguard Worker const char* nth,
480*4bdc9457SAndroid Build Coastguard Worker enum xnn_node_type node_type)
481*4bdc9457SAndroid Build Coastguard Worker {
482*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* input_value = &subgraph->values[input_id];
483*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* output_value = &subgraph->values[output_id];
484*4bdc9457SAndroid Build Coastguard Worker if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
485*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
486*4bdc9457SAndroid Build Coastguard Worker "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
487*4bdc9457SAndroid Build Coastguard Worker ": mismatching quantization zero point across the %s input (%" PRId32 ") and the output (%" PRId32 ")",
488*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node_type), input_id, output_id,
489*4bdc9457SAndroid Build Coastguard Worker nth, input_value->quantization.zero_point, output_value->quantization.zero_point);
490*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_parameter;
491*4bdc9457SAndroid Build Coastguard Worker }
492*4bdc9457SAndroid Build Coastguard Worker if (input_value->quantization.scale != output_value->quantization.scale) {
493*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
494*4bdc9457SAndroid Build Coastguard Worker "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
495*4bdc9457SAndroid Build Coastguard Worker ": mismatching quantization scale across the %s input (%.7g) and the output (%.7g)",
496*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node_type), input_id, output_id,
497*4bdc9457SAndroid Build Coastguard Worker nth, input_value->quantization.scale, output_value->quantization.scale);
498*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_parameter;
499*4bdc9457SAndroid Build Coastguard Worker }
500*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
501*4bdc9457SAndroid Build Coastguard Worker }
502*4bdc9457SAndroid Build Coastguard Worker #endif // !defined( XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
503*4bdc9457SAndroid Build Coastguard Worker
// Validates and appends a Concatenate node with `num_inputs` inputs to `subgraph`.
//
// Parameters:
//   node_type  - node type recorded on the new node and used in log messages.
//   subgraph   - subgraph that owns the Values and receives the new node.
//   axis       - dimension along which the inputs are concatenated.
//   num_inputs - number of entries in `input_ids`; asserted to be in [2, 4].
//   input_ids  - Value IDs of the inputs, in concatenation order.
//   output_id  - Value ID of the output.
//   flags      - stored on the node unchanged.
//
// Validation performed before the node is created:
//   - XNNPACK is initialized, the output ID is valid and the output is dense;
//   - `axis` is smaller than the output's number of dimensions;
//   - every input passes check_input_value();
//   - the output extent along `axis` equals the sum of the input extents;
//   - the output datatype maps to a supported compute type;
//   - for quantized (qs8/qu8) compute types, every input has the same
//     zero point and scale as the output.
//
// Returns xnn_status_success on success, xnn_status_invalid_parameter on a
// failed shape/datatype/quantization check, xnn_status_out_of_memory if a new
// node cannot be allocated, or the status of the first failing helper.
enum xnn_status xnn_define_concatenate_n(
  enum xnn_node_type node_type,
  xnn_subgraph_t subgraph,
  size_t axis,
  size_t num_inputs,
  uint32_t* input_ids,
  uint32_t output_id,
  uint32_t flags)
{
  assert(num_inputs >= 2);
  assert(num_inputs <= 4);

  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(node_type)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_node_id(node_type, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];

  status = xnn_subgraph_check_output_type_dense(node_type, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  if (axis >= output_value->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with the output ID #%" PRIu32
      ": axis (%zu) exceeds the number of dimensions (%zu)",
      xnn_node_type_to_string(node_type), output_id, axis, output_value->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  for (size_t i = 0; i < num_inputs; i++) {
    status = check_input_value(subgraph, axis, input_ids[i], output_id, i+1, node_type);
    if (status != xnn_status_success) {
      return status;
    }
  }

  // Input IDs were validated by check_input_value() above, so indexing
  // subgraph->values with them is safe from here on.
  size_t input_axis_dimensions_sum = 0;
  for (size_t i = 0; i < num_inputs; i++) {
    const struct xnn_value* input_value = &subgraph->values[input_ids[i]];
    input_axis_dimensions_sum += input_value->shape.dim[axis];
  }

  if (output_value->shape.dim[axis] != input_axis_dimensions_sum) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32
      ": mismatch axis dimension %zu, output has %zu, sum of input dimensions is %zu",
      xnn_node_type_to_string(node_type), output_id, axis, output_value->shape.dim[axis], input_axis_dimensions_sum);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_datatype_fp16:
      compute_type = xnn_compute_type_fp16;
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
  // Quantization parameters are only meaningful for quantized tensors, so all
  // inputs (not just the first two) are checked inside this guard, and a
  // mismatch aborts the definition instead of being logged and ignored.
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    static const char* const input_ordinals[4] = { "first", "second", "third", "fourth" };
    for (size_t i = 0; i < num_inputs; i++) {
      status = check_input_compute_type(subgraph, input_ids[i], output_id, input_ordinals[i], node_type);
      if (status != xnn_status_success) {
        return status;
      }
    }
  }
#endif  // !defined( XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.concatenate.axis = axis;
  node->type = node_type;
  node->compute_type = compute_type;
  node->num_inputs = num_inputs;
  node->inputs[0] = input_ids[0];
  node->inputs[1] = input_ids[1];
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  // Select the create/setup callbacks for this arity and record the
  // remaining input IDs.
  switch (num_inputs) {
    case 2:
      node->create = create_concatenate2_operator;
      node->setup = setup_concatenate2_operator;
      break;
    case 3:
      node->create = create_concatenate3_operator;
      node->setup = setup_concatenate3_operator;
      node->inputs[2] = input_ids[2];
      break;
    case 4:
      node->create = create_concatenate4_operator;
      node->setup = setup_concatenate4_operator;
      node->inputs[2] = input_ids[2];
      node->inputs[3] = input_ids[3];
      break;
    default:
      XNN_UNREACHABLE;
  }

  return xnn_status_success;
}
640*4bdc9457SAndroid Build Coastguard Worker
// Defines a two-input Concatenate node by packing the input IDs into an array
// and forwarding to xnn_define_concatenate_n() with node type
// xnn_node_type_concatenate2. Returns whatever that call returns.
enum xnn_status xnn_define_concatenate2(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[2];
  inputs[0] = input1_id;
  inputs[1] = input2_id;

  const size_t input_count = XNN_COUNT_OF(inputs);
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate2, subgraph, axis, input_count, inputs, output_id, flags);
}
653*4bdc9457SAndroid Build Coastguard Worker
// Defines a three-input Concatenate node by packing the input IDs into an
// array and forwarding to xnn_define_concatenate_n() with node type
// xnn_node_type_concatenate3. Returns whatever that call returns.
enum xnn_status xnn_define_concatenate3(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[3];
  inputs[0] = input1_id;
  inputs[1] = input2_id;
  inputs[2] = input3_id;

  const size_t input_count = XNN_COUNT_OF(inputs);
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate3, subgraph, axis, input_count, inputs, output_id, flags);
}
667*4bdc9457SAndroid Build Coastguard Worker
// Defines a four-input Concatenate node by packing the input IDs into an
// array and forwarding to xnn_define_concatenate_n() with node type
// xnn_node_type_concatenate4. Returns whatever that call returns.
enum xnn_status xnn_define_concatenate4(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t input4_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[4];
  inputs[0] = input1_id;
  inputs[1] = input2_id;
  inputs[2] = input3_id;
  inputs[3] = input4_id;

  const size_t input_count = XNN_COUNT_OF(inputs);
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate4, subgraph, axis, input_count, inputs, output_id, flags);
}
682