1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <stdint.h>
8
9 #include <xnnpack.h>
10 #include <xnnpack/log.h>
11 #include <xnnpack/operator.h>
12 #include <xnnpack/params.h>
13 #include <xnnpack/subgraph.h>
14 #include <xnnpack/subgraph-validation.h>
15
// Creates the copy operator that implements one arm of a concatenation.
// Each input of an N-way concatenation is lowered to a strided copy into a
// disjoint channel range of the shared output; the element width of the copy
// is selected from the node's compute type.
//
//   channels      - elements copied per batch element for this input
//   input_stride  - spacing between consecutive batch elements in the input
//   output_stride - spacing between consecutive batch elements in the output
//                   (callers pass the sum of all inputs' channels)
//   index         - slot in opdata->operator_objects that receives the operator
static enum xnn_status create_concatenate_operator_helper(
  const struct xnn_node *node,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  struct xnn_operator_data *opdata,
  size_t index)
{
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16: {
      // fp16 values are copied as opaque 16-bit words.
      return xnn_create_copy_nc_x16(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32: {
      return xnn_create_copy_nc_x32(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_compute_type_qs8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_compute_type_qu8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    {
      // Both 8-bit quantized datatypes share the byte-wide copy.
      return xnn_create_copy_nc_x8(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
48
create_concatenate2_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)49 static enum xnn_status create_concatenate2_operator(
50 const struct xnn_node* node,
51 const struct xnn_value* values,
52 size_t num_values,
53 struct xnn_operator_data* opdata,
54 const struct xnn_caches* caches)
55 {
56 assert(node->num_inputs == 2);
57 const uint32_t input1_id = node->inputs[0];
58 assert(input1_id != XNN_INVALID_VALUE_ID);
59 assert(input1_id < num_values);
60 const uint32_t input2_id = node->inputs[1];
61 assert(input2_id != XNN_INVALID_VALUE_ID);
62 assert(input2_id < num_values);
63
64 assert(node->num_outputs == 1);
65 const uint32_t output_id = node->outputs[0];
66 assert(output_id != XNN_INVALID_VALUE_ID);
67 assert(output_id < num_values);
68
69 const size_t axis = node->params.concatenate.axis;
70 size_t batch_size = 1, channels_1 = 1, channels_2 = 1;
71 for (size_t i = 0; i < axis; i++) {
72 batch_size *= values[output_id].shape.dim[i];
73 }
74
75 for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
76 channels_1 *= values[input1_id].shape.dim[i];
77 channels_2 *= values[input2_id].shape.dim[i];
78 }
79 const size_t output_stride = channels_1 + channels_2;
80
81 enum xnn_status status;
82 status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
83 if (status != xnn_status_success) {
84 return status;
85 }
86 status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
87 if (status != xnn_status_success) {
88 return status;
89 }
90
91 opdata->inputs[0] = input1_id;
92 opdata->inputs[1] = input2_id;
93 opdata->outputs[0] = output_id;
94 opdata->batch_size = batch_size;
95
96 return status;
97 }
98
create_concatenate3_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)99 static enum xnn_status create_concatenate3_operator(
100 const struct xnn_node* node,
101 const struct xnn_value* values,
102 size_t num_values,
103 struct xnn_operator_data* opdata,
104 const struct xnn_caches* caches)
105 {
106 assert(node->num_inputs == 3);
107 const uint32_t input1_id = node->inputs[0];
108 assert(input1_id != XNN_INVALID_VALUE_ID);
109 assert(input1_id < num_values);
110 const uint32_t input2_id = node->inputs[1];
111 assert(input2_id != XNN_INVALID_VALUE_ID);
112 assert(input2_id < num_values);
113 const uint32_t input3_id = node->inputs[2];
114 assert(input3_id != XNN_INVALID_VALUE_ID);
115 assert(input3_id < num_values);
116
117 assert(node->num_outputs == 1);
118 const uint32_t output_id = node->outputs[0];
119 assert(output_id != XNN_INVALID_VALUE_ID);
120 assert(output_id < num_values);
121
122 const size_t axis = node->params.concatenate.axis;
123 size_t batch_size = 1, channels_1 = 1, channels_2 = 1, channels_3 = 1;
124 for (size_t i = 0; i < axis; i++) {
125 batch_size *= values[output_id].shape.dim[i];
126 }
127
128 for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
129 channels_1 *= values[input1_id].shape.dim[i];
130 channels_2 *= values[input2_id].shape.dim[i];
131 channels_3 *= values[input3_id].shape.dim[i];
132 }
133 const size_t output_stride = channels_1 + channels_2 + channels_3;
134
135 enum xnn_status status;
136 status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
137 if (status != xnn_status_success) {
138 return status;
139 }
140 status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
141 if (status != xnn_status_success) {
142 return status;
143 }
144 status = create_concatenate_operator_helper(node, channels_3, channels_3, output_stride, opdata, 2);
145 if (status != xnn_status_success) {
146 return status;
147 }
148
149 opdata->inputs[0] = input1_id;
150 opdata->inputs[1] = input2_id;
151 opdata->inputs[2] = input3_id;
152 opdata->outputs[0] = output_id;
153 opdata->batch_size = batch_size;
154
155 return status;
156 }
157
create_concatenate4_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)158 static enum xnn_status create_concatenate4_operator(
159 const struct xnn_node* node,
160 const struct xnn_value* values,
161 size_t num_values,
162 struct xnn_operator_data* opdata,
163 const struct xnn_caches* caches)
164 {
165 assert(node->num_inputs == 4);
166 const uint32_t input1_id = node->inputs[0];
167 assert(input1_id != XNN_INVALID_VALUE_ID);
168 assert(input1_id < num_values);
169 const uint32_t input2_id = node->inputs[1];
170 assert(input2_id != XNN_INVALID_VALUE_ID);
171 assert(input2_id < num_values);
172 const uint32_t input3_id = node->inputs[2];
173 assert(input3_id != XNN_INVALID_VALUE_ID);
174 assert(input3_id < num_values);
175 const uint32_t input4_id = node->inputs[3];
176 assert(input4_id != XNN_INVALID_VALUE_ID);
177 assert(input4_id < num_values);
178
179 assert(node->num_outputs == 1);
180 const uint32_t output_id = node->outputs[0];
181 assert(output_id != XNN_INVALID_VALUE_ID);
182 assert(output_id < num_values);
183
184 const size_t axis = node->params.concatenate.axis;
185 size_t batch_size = 1, channels_1 = 1, channels_2 = 1, channels_3 = 1, channels_4 = 1;
186 for (size_t i = 0; i < axis; i++) {
187 batch_size *= values[output_id].shape.dim[i];
188 }
189
190 for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
191 channels_1 *= values[input1_id].shape.dim[i];
192 channels_2 *= values[input2_id].shape.dim[i];
193 channels_3 *= values[input3_id].shape.dim[i];
194 channels_4 *= values[input4_id].shape.dim[i];
195 }
196 const size_t output_stride = channels_1 + channels_2 + channels_3 + channels_4;
197
198 enum xnn_status status;
199 status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
200 if (status != xnn_status_success) {
201 return status;
202 }
203 status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
204 if (status != xnn_status_success) {
205 return status;
206 }
207 status = create_concatenate_operator_helper(node, channels_3, channels_3, output_stride, opdata, 2);
208 if (status != xnn_status_success) {
209 return status;
210 }
211 status = create_concatenate_operator_helper(node, channels_4, channels_4, output_stride, opdata, 3);
212 if (status != xnn_status_success) {
213 return status;
214 }
215
216 opdata->inputs[0] = input1_id;
217 opdata->inputs[1] = input2_id;
218 opdata->inputs[2] = input3_id;
219 opdata->inputs[3] = input4_id;
220 opdata->outputs[0] = output_id;
221 opdata->batch_size = batch_size;
222
223 return status;
224 }
225
// Binds input/output pointers to one arm of a concatenation: the operator at
// `index` reads from input_data and writes into output_data at an element
// offset equal to the combined channels of all preceding arms.  The element
// width is derived from the operator type chosen at create time.
static enum xnn_status setup_concatenate_operator_helper(
  const void* input_data,
  void* output_data,
  const struct xnn_operator_data *opdata,
  size_t index,
  pthreadpool_t threadpool)
{
  // The output pointer of this operator is the sum of all channels of the earlier operators.
  size_t channels = 0;
  for (size_t i = 0; i < index; i++) {
    channels += opdata->operator_objects[i]->channels;
  }

  switch (opdata->operator_objects[index]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_copy_nc_x16: {
      // `channels` counts elements, so advance a 16-bit pointer.
      return xnn_setup_copy_nc_x16(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint16_t*) output_data + channels,
        threadpool);
    }
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_copy_nc_x32: {
      return xnn_setup_copy_nc_x32(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint32_t*) output_data + channels,
        threadpool);
    }
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8: {
      return xnn_setup_copy_nc_x8(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint8_t*) output_data + channels,
        threadpool);
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
272
// Binds blob pointers for both arms of a 2-input concatenation and runs the
// per-arm setup in input order.
static enum xnn_status setup_concatenate2_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  uint32_t input_ids[2];
  for (size_t n = 0; n < 2; n++) {
    input_ids[n] = opdata->inputs[n];
    assert(input_ids[n] != XNN_INVALID_VALUE_ID);
    assert(input_ids[n] < num_blobs);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data[2];
  for (size_t n = 0; n < 2; n++) {
    input_data[n] = blobs[input_ids[n]].data;
    assert(input_data[n] != NULL);
  }

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  enum xnn_status status = xnn_status_success;
  for (size_t n = 0; n < 2; n++) {
    status = setup_concatenate_operator_helper(input_data[n], output_data, opdata, n, threadpool);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return status;
}
311
// Binds blob pointers for all arms of a 3-input concatenation and runs the
// per-arm setup in input order.
static enum xnn_status setup_concatenate3_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  uint32_t input_ids[3];
  for (size_t n = 0; n < 3; n++) {
    input_ids[n] = opdata->inputs[n];
    assert(input_ids[n] != XNN_INVALID_VALUE_ID);
    assert(input_ids[n] < num_blobs);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data[3];
  for (size_t n = 0; n < 3; n++) {
    input_data[n] = blobs[input_ids[n]].data;
    assert(input_data[n] != NULL);
  }

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  enum xnn_status status = xnn_status_success;
  for (size_t n = 0; n < 3; n++) {
    status = setup_concatenate_operator_helper(input_data[n], output_data, opdata, n, threadpool);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return status;
}
362
// Binds blob pointers for all arms of a 4-input concatenation and runs the
// per-arm setup in input order.
static enum xnn_status setup_concatenate4_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  uint32_t input_ids[4];
  for (size_t n = 0; n < 4; n++) {
    input_ids[n] = opdata->inputs[n];
    assert(input_ids[n] != XNN_INVALID_VALUE_ID);
    assert(input_ids[n] < num_blobs);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data[4];
  for (size_t n = 0; n < 4; n++) {
    input_data[n] = blobs[input_ids[n]].data;
    assert(input_data[n] != NULL);
  }

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  enum xnn_status status = xnn_status_success;
  for (size_t n = 0; n < 4; n++) {
    status = setup_concatenate_operator_helper(input_data[n], output_data, opdata, n, threadpool);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return status;
}
425
// Validates the nth (1-based) input of a concatenation node: the value must be
// a valid dense input, match the output's dimensionality and every non-axis
// dimension, and have the same datatype as the output.
enum xnn_status check_input_value(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input_id,
  uint32_t output_id,
  size_t nth,
  enum xnn_node_type node_type)
{
  enum xnn_status status =
    xnn_subgraph_check_nth_input_node_id(node_type, input_id, subgraph->num_values, nth);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(node_type, input_id, input);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output = &subgraph->values[output_id];
  if (input->shape.num_dims != output->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with input %zu ID #%" PRIu32
      ": mismatch number of dimensions, input %zu has %zu, output has %zu",
      xnn_node_type_to_string(node_type), nth, input_id, nth, input->shape.num_dims,
      output->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  // All dimensions except the concatenation axis must agree with the output.
  for (size_t d = 0; d < input->shape.num_dims; d++) {
    if (d == axis) {
      continue;
    }
    if (input->shape.dim[d] != output->shape.dim[d]) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32
        ": mismatch dimension %zu, input %zu has %zu, output has %zu",
        xnn_node_type_to_string(node_type), input_id, d, nth, input->shape.dim[d], output->shape.dim[d]);
      return xnn_status_invalid_parameter;
    }
  }

  return xnn_subgraph_check_datatype_matches(node_type, input_id, input, output_id, output);
}
473
474 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
// Verifies that a quantized input's quantization parameters (zero point and
// scale) match the output's.  Concatenation is lowered to raw byte copies,
// which cannot requantize, so any mismatch is rejected.
enum xnn_status check_input_compute_type(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  const char* nth,
  enum xnn_node_type node_type)
{
  const struct xnn_value* input = &subgraph->values[input_id];
  const struct xnn_value* output = &subgraph->values[output_id];

  if (input->quantization.zero_point != output->quantization.zero_point) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching quantization zero point across the %s input (%" PRId32 ") and the output (%" PRId32 ")",
      xnn_node_type_to_string(node_type), input_id, output_id,
      nth, input->quantization.zero_point, output->quantization.zero_point);
    return xnn_status_invalid_parameter;
  }

  if (input->quantization.scale != output->quantization.scale) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching quantization scale across the %s input (%.7g) and the output (%.7g)",
      xnn_node_type_to_string(node_type), input_id, output_id,
      nth, input->quantization.scale, output->quantization.scale);
    return xnn_status_invalid_parameter;
  }

  return xnn_status_success;
}
502 #endif // !defined( XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
503
// Defines an N-way (2 <= N <= 4) concatenation node along `axis`.
//
// Validation performed before the node is recorded:
//   - XNNPACK is initialized and output_id names a valid dense output;
//   - axis is within the output's rank;
//   - every input is dense, matches the output's rank, non-axis dimensions,
//     and datatype (check_input_value);
//   - the output's axis dimension equals the sum of the inputs' axis
//     dimensions;
//   - for quantized datatypes only, every input's quantization parameters
//     match the output's (check_input_compute_type).
//
// Returns xnn_status_success on success, an error status otherwise; on
// success the node's create/setup callbacks are set for the requested arity.
enum xnn_status xnn_define_concatenate_n(
  enum xnn_node_type node_type,
  xnn_subgraph_t subgraph,
  size_t axis,
  size_t num_inputs,
  uint32_t* input_ids,
  uint32_t output_id,
  uint32_t flags)
{
  assert(num_inputs >= 2);
  assert(num_inputs <= 4);

  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(node_type)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_node_id(node_type, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];

  status = xnn_subgraph_check_output_type_dense(node_type, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  if (axis >= output_value->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with the output ID #%" PRIu32
      ": axis (%zu) exceeds the number of dimensions (%zu)",
      xnn_node_type_to_string(node_type), output_id, axis, output_value->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  for (size_t i = 0; i < num_inputs; i++) {
    status = check_input_value(subgraph, axis, input_ids[i], output_id, i+1, node_type);
    if (status != xnn_status_success) {
      return status;
    }
  }

  // The output's axis dimension must be exactly the sum of the inputs'.
  size_t input_axis_dimensions_sum = 0;
  for (size_t i = 0; i < num_inputs; i++) {
    const struct xnn_value* input_value = &subgraph->values[input_ids[i]];
    input_axis_dimensions_sum += input_value->shape.dim[axis];
  }

  if (output_value->shape.dim[axis] != input_axis_dimensions_sum) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32
      ": mismatch axis dimension %zu, output has %zu, sum of input dimensions is %zu",
      xnn_node_type_to_string(node_type), output_id, axis, output_value->shape.dim[axis], input_axis_dimensions_sum);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_datatype_fp16:
      compute_type = xnn_compute_type_fp16;
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
  // Quantization-parameter checks apply only to quantized compute types, and
  // any mismatch must fail node creation, not just log an error.
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    static const char* nth_names[4] = { "first", "second", "third", "fourth" };
    for (size_t i = 0; i < num_inputs; i++) {
      status = check_input_compute_type(subgraph, input_ids[i], output_id, nth_names[i], node_type);
      if (status != xnn_status_success) {
        return status;
      }
    }
  }
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.concatenate.axis = axis;
  node->type = node_type;
  node->compute_type = compute_type;
  node->num_inputs = num_inputs;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;
  for (size_t i = 0; i < num_inputs; i++) {
    node->inputs[i] = input_ids[i];
  }

  switch (num_inputs) {
    case 2:
      node->create = create_concatenate2_operator;
      node->setup = setup_concatenate2_operator;
      break;
    case 3:
      node->create = create_concatenate3_operator;
      node->setup = setup_concatenate3_operator;
      break;
    case 4:
      node->create = create_concatenate4_operator;
      node->setup = setup_concatenate4_operator;
      break;
    default:
      XNN_UNREACHABLE;
  }

  return xnn_status_success;
}
640
// Public entry point: 2-input concatenation, forwarded to the generic
// N-input implementation.
enum xnn_status xnn_define_concatenate2(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[] = { input1_id, input2_id };
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate2, subgraph, axis, XNN_COUNT_OF(inputs), inputs, output_id, flags);
}
653
// Public entry point: 3-input concatenation, forwarded to the generic
// N-input implementation.
enum xnn_status xnn_define_concatenate3(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[] = { input1_id, input2_id, input3_id };
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate3, subgraph, axis, XNN_COUNT_OF(inputs), inputs, output_id, flags);
}
667
// Public entry point: 4-input concatenation, forwarded to the generic
// N-input implementation.
enum xnn_status xnn_define_concatenate4(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t input4_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[] = { input1_id, input2_id, input3_id, input4_id };
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate4, subgraph, axis, XNN_COUNT_OF(inputs), inputs, output_id, flags);
}
682