xref: /aosp_15_r20/external/XNNPACK/src/subgraph/concatenate.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <stdint.h>
8 
9 #include <xnnpack.h>
10 #include <xnnpack/log.h>
11 #include <xnnpack/operator.h>
12 #include <xnnpack/params.h>
13 #include <xnnpack/subgraph.h>
14 #include <xnnpack/subgraph-validation.h>
15 
// Creates the copy operator that implements one input of a concatenation node
// and stores it in opdata->operator_objects[index].
//
// Each input is realized as a strided copy: `channels` elements are copied per
// batch element, reading with `input_stride` and writing with `output_stride`
// (the sum of all inputs' channel counts). The element width of the copy
// operator is selected from the node's compute type.
static enum xnn_status create_concatenate_operator_helper(
  const struct xnn_node *node,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  struct xnn_operator_data *opdata,
  size_t index)
{
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16: {
      // 16-bit elements (half-precision float).
      return xnn_create_copy_nc_x16(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32: {
      // 32-bit elements.
      return xnn_create_copy_nc_x32(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_compute_type_qs8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_compute_type_qu8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    {
      // Both quantized types deliberately fall through to the same 8-bit copy;
      // quantization parameters are checked to match at graph-definition time.
      return xnn_create_copy_nc_x8(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
48 
create_concatenate2_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)49 static enum xnn_status create_concatenate2_operator(
50   const struct xnn_node* node,
51   const struct xnn_value* values,
52   size_t num_values,
53   struct xnn_operator_data* opdata,
54   const struct xnn_caches* caches)
55 {
56   assert(node->num_inputs == 2);
57   const uint32_t input1_id = node->inputs[0];
58   assert(input1_id != XNN_INVALID_VALUE_ID);
59   assert(input1_id < num_values);
60   const uint32_t input2_id = node->inputs[1];
61   assert(input2_id != XNN_INVALID_VALUE_ID);
62   assert(input2_id < num_values);
63 
64   assert(node->num_outputs == 1);
65   const uint32_t output_id = node->outputs[0];
66   assert(output_id != XNN_INVALID_VALUE_ID);
67   assert(output_id < num_values);
68 
69   const size_t axis = node->params.concatenate.axis;
70   size_t batch_size = 1, channels_1 = 1, channels_2 = 1;
71   for (size_t i = 0; i < axis; i++) {
72     batch_size *= values[output_id].shape.dim[i];
73   }
74 
75   for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
76     channels_1 *= values[input1_id].shape.dim[i];
77     channels_2 *= values[input2_id].shape.dim[i];
78   }
79   const size_t output_stride = channels_1 + channels_2;
80 
81   enum xnn_status status;
82   status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
83   if (status != xnn_status_success) {
84     return status;
85   }
86   status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
87   if (status != xnn_status_success) {
88     return status;
89   }
90 
91   opdata->inputs[0] = input1_id;
92   opdata->inputs[1] = input2_id;
93   opdata->outputs[0] = output_id;
94   opdata->batch_size = batch_size;
95 
96   return status;
97 }
98 
create_concatenate3_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)99 static enum xnn_status create_concatenate3_operator(
100   const struct xnn_node* node,
101   const struct xnn_value* values,
102   size_t num_values,
103   struct xnn_operator_data* opdata,
104   const struct xnn_caches* caches)
105 {
106   assert(node->num_inputs == 3);
107   const uint32_t input1_id = node->inputs[0];
108   assert(input1_id != XNN_INVALID_VALUE_ID);
109   assert(input1_id < num_values);
110   const uint32_t input2_id = node->inputs[1];
111   assert(input2_id != XNN_INVALID_VALUE_ID);
112   assert(input2_id < num_values);
113   const uint32_t input3_id = node->inputs[2];
114   assert(input3_id != XNN_INVALID_VALUE_ID);
115   assert(input3_id < num_values);
116 
117   assert(node->num_outputs == 1);
118   const uint32_t output_id = node->outputs[0];
119   assert(output_id != XNN_INVALID_VALUE_ID);
120   assert(output_id < num_values);
121 
122   const size_t axis = node->params.concatenate.axis;
123   size_t batch_size = 1, channels_1 = 1, channels_2 = 1, channels_3 = 1;
124   for (size_t i = 0; i < axis; i++) {
125     batch_size *= values[output_id].shape.dim[i];
126   }
127 
128   for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
129     channels_1 *= values[input1_id].shape.dim[i];
130     channels_2 *= values[input2_id].shape.dim[i];
131     channels_3 *= values[input3_id].shape.dim[i];
132   }
133   const size_t output_stride = channels_1 + channels_2 + channels_3;
134 
135   enum xnn_status status;
136   status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
137   if (status != xnn_status_success) {
138     return status;
139   }
140   status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
141   if (status != xnn_status_success) {
142     return status;
143   }
144   status = create_concatenate_operator_helper(node, channels_3, channels_3, output_stride, opdata, 2);
145   if (status != xnn_status_success) {
146     return status;
147   }
148 
149   opdata->inputs[0] = input1_id;
150   opdata->inputs[1] = input2_id;
151   opdata->inputs[2] = input3_id;
152   opdata->outputs[0] = output_id;
153   opdata->batch_size = batch_size;
154 
155   return status;
156 }
157 
create_concatenate4_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)158 static enum xnn_status create_concatenate4_operator(
159   const struct xnn_node* node,
160   const struct xnn_value* values,
161   size_t num_values,
162   struct xnn_operator_data* opdata,
163   const struct xnn_caches* caches)
164 {
165   assert(node->num_inputs == 4);
166   const uint32_t input1_id = node->inputs[0];
167   assert(input1_id != XNN_INVALID_VALUE_ID);
168   assert(input1_id < num_values);
169   const uint32_t input2_id = node->inputs[1];
170   assert(input2_id != XNN_INVALID_VALUE_ID);
171   assert(input2_id < num_values);
172   const uint32_t input3_id = node->inputs[2];
173   assert(input3_id != XNN_INVALID_VALUE_ID);
174   assert(input3_id < num_values);
175   const uint32_t input4_id = node->inputs[3];
176   assert(input4_id != XNN_INVALID_VALUE_ID);
177   assert(input4_id < num_values);
178 
179   assert(node->num_outputs == 1);
180   const uint32_t output_id = node->outputs[0];
181   assert(output_id != XNN_INVALID_VALUE_ID);
182   assert(output_id < num_values);
183 
184   const size_t axis = node->params.concatenate.axis;
185   size_t batch_size = 1, channels_1 = 1, channels_2 = 1, channels_3 = 1, channels_4 = 1;
186   for (size_t i = 0; i < axis; i++) {
187     batch_size *= values[output_id].shape.dim[i];
188   }
189 
190   for (size_t i = axis; i < values[input1_id].shape.num_dims; i++) {
191     channels_1 *= values[input1_id].shape.dim[i];
192     channels_2 *= values[input2_id].shape.dim[i];
193     channels_3 *= values[input3_id].shape.dim[i];
194     channels_4 *= values[input4_id].shape.dim[i];
195   }
196   const size_t output_stride = channels_1 + channels_2 + channels_3 + channels_4;
197 
198   enum xnn_status status;
199   status = create_concatenate_operator_helper(node, channels_1, channels_1, output_stride, opdata, 0);
200   if (status != xnn_status_success) {
201     return status;
202   }
203   status = create_concatenate_operator_helper(node, channels_2, channels_2, output_stride, opdata, 1);
204   if (status != xnn_status_success) {
205     return status;
206   }
207   status = create_concatenate_operator_helper(node, channels_3, channels_3, output_stride, opdata, 2);
208   if (status != xnn_status_success) {
209     return status;
210   }
211   status = create_concatenate_operator_helper(node, channels_4, channels_4, output_stride, opdata, 3);
212   if (status != xnn_status_success) {
213     return status;
214   }
215 
216   opdata->inputs[0] = input1_id;
217   opdata->inputs[1] = input2_id;
218   opdata->inputs[2] = input3_id;
219   opdata->inputs[3] = input4_id;
220   opdata->outputs[0] = output_id;
221   opdata->batch_size = batch_size;
222 
223   return status;
224 }
225 
// Attaches buffers to one of a concatenation's copy operators: the input
// pointer, the batch size, and this operator's element offset inside the
// shared output buffer.
static enum xnn_status setup_concatenate_operator_helper(
  const void* input_data,
  void* output_data,
  const struct xnn_operator_data *opdata,
  size_t index,
  pthreadpool_t threadpool)
{
  // The output pointer of this operator is the sum of all channels of the earlier operators.
  size_t channels = 0;
  for (size_t i = 0; i < index; i++) {
    channels += opdata->operator_objects[i]->channels;
  }

  // Dispatch on the operator type chosen at creation time; the pointer
  // arithmetic below must use the matching element width.
  switch (opdata->operator_objects[index]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_copy_nc_x16: {
      // 16-bit elements: offset in uint16_t units.
      return xnn_setup_copy_nc_x16(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint16_t*) output_data + channels,
        threadpool);
    }
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_copy_nc_x32: {
      // 32-bit elements: offset in uint32_t units.
      return xnn_setup_copy_nc_x32(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint32_t*) output_data + channels,
        threadpool);
    }
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8: {
      // 8-bit elements (qs8/qu8): offset in bytes.
      return xnn_setup_copy_nc_x8(
        opdata->operator_objects[index],
        opdata->batch_size,
        input_data,
        (uint8_t*) output_data + channels,
        threadpool);
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
272 
// Binds input/output buffers to the two copy operators of a 2-input concatenation.
static enum xnn_status setup_concatenate2_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input1_id = opdata->inputs[0];
  const uint32_t input2_id = opdata->inputs[1];
  const uint32_t output_id = opdata->outputs[0];

  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_blobs);
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_blobs);
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input1_data = blobs[input1_id].data;
  assert(input1_data != NULL);
  const void* input2_data = blobs[input2_id].data;
  assert(input2_data != NULL);
  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  // The first operator writes at channel offset 0, the second just past the
  // first input's channels (the helper computes the offset from `index`).
  const enum xnn_status status =
    setup_concatenate_operator_helper(input1_data, output_data, opdata, 0, threadpool);
  if (status != xnn_status_success) {
    return status;
  }
  return setup_concatenate_operator_helper(input2_data, output_data, opdata, 1, threadpool);
}
311 
// Binds input/output buffers to the three copy operators of a 3-input concatenation.
static enum xnn_status setup_concatenate3_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  // Resolve and validate every buffer before configuring any operator.
  const void* input_data[3];
  for (size_t i = 0; i < 3; i++) {
    const uint32_t input_id = opdata->inputs[i];
    assert(input_id != XNN_INVALID_VALUE_ID);
    assert(input_id < num_blobs);
    input_data[i] = blobs[input_id].data;
    assert(input_data[i] != NULL);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);
  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  // All operators write into the same output buffer; the helper offsets each
  // one past the channels of the operators before it.
  enum xnn_status status = xnn_status_success;
  for (size_t i = 0; i < 3; i++) {
    status = setup_concatenate_operator_helper(input_data[i], output_data, opdata, i, threadpool);
    if (status != xnn_status_success) {
      break;
    }
  }
  return status;
}
362 
// Binds input/output buffers to the four copy operators of a 4-input concatenation.
static enum xnn_status setup_concatenate4_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  // Resolve and validate every buffer before configuring any operator.
  const void* input_data[4];
  for (size_t i = 0; i < 4; i++) {
    const uint32_t input_id = opdata->inputs[i];
    assert(input_id != XNN_INVALID_VALUE_ID);
    assert(input_id < num_blobs);
    input_data[i] = blobs[input_id].data;
    assert(input_data[i] != NULL);
  }

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);
  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  // All operators write into the same output buffer; the helper offsets each
  // one past the channels of the operators before it.
  enum xnn_status status = xnn_status_success;
  for (size_t i = 0; i < 4; i++) {
    status = setup_concatenate_operator_helper(input_data[i], output_data, opdata, i, threadpool);
    if (status != xnn_status_success) {
      break;
    }
  }
  return status;
}
425 
// Validates the nth (1-based) input of a concatenation node against the
// node's output: valid node id, dense type, equal rank, matching size on
// every dimension except the concatenation axis, and matching datatype.
enum xnn_status check_input_value(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input_id,
  uint32_t output_id,
  size_t nth,
  enum xnn_node_type node_type)
{
  enum xnn_status status =
    xnn_subgraph_check_nth_input_node_id(node_type, input_id, subgraph->num_values, nth);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(node_type, input_id, input);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output = &subgraph->values[output_id];
  if (input->shape.num_dims != output->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with input %zu ID #%" PRIu32
      ": mismatch number of dimensions, input %zu has %zu, output has %zu",
      xnn_node_type_to_string(node_type), nth, input_id, nth, input->shape.num_dims,
      output->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  // Every dimension except the concatenation axis must match the output.
  for (size_t i = 0; i < input->shape.num_dims; i++) {
    if (i == axis) {
      continue;
    }
    if (input->shape.dim[i] != output->shape.dim[i]) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32
        ": mismatch dimension %zu, input %zu has %zu, output has %zu",
        xnn_node_type_to_string(node_type), input_id, i, nth, input->shape.dim[i], output->shape.dim[i]);
      return xnn_status_invalid_parameter;
    }
  }

  return xnn_subgraph_check_datatype_matches(node_type, input_id, input, output_id, output);
}
473 
474 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
// Verifies that a quantized input shares the output's quantization parameters
// (`nth` is a human-readable ordinal like "first", used only in log messages).
// Concatenation is implemented as a plain byte copy, so zero point and scale
// must match exactly between each input and the output.
enum xnn_status check_input_compute_type(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  const char* nth,
  enum xnn_node_type node_type)
{
  const struct xnn_value* input = &subgraph->values[input_id];
  const struct xnn_value* output = &subgraph->values[output_id];

  if (input->quantization.zero_point != output->quantization.zero_point) {
    xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching quantization zero point across the %s input (%" PRId32 ") and the output (%" PRId32 ")",
        xnn_node_type_to_string(node_type), input_id, output_id,
        nth, input->quantization.zero_point, output->quantization.zero_point);
    return xnn_status_invalid_parameter;
  }

  if (input->quantization.scale != output->quantization.scale) {
    xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching quantization scale across the %s input (%.7g) and the output (%.7g)",
        xnn_node_type_to_string(node_type), input_id, output_id,
        nth, input->quantization.scale, output->quantization.scale);
    return xnn_status_invalid_parameter;
  }

  return xnn_status_success;
}
502 #endif  // !defined( XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
503 
// Defines an N-input (2 <= N <= 4) concatenation node in the subgraph.
//
// Validates the output and every input: valid ids, dense types, shape
// compatibility on all non-axis dimensions, the output's axis dimension
// equals the sum of the inputs' axis dimensions, and — for quantized
// datatypes — matching quantization parameters. On success appends a node
// whose create/setup callbacks match num_inputs.
//
// Returns xnn_status_success, or an error status describing the first
// validation failure.
enum xnn_status xnn_define_concatenate_n(
  enum xnn_node_type node_type,
  xnn_subgraph_t subgraph,
  size_t axis,
  size_t num_inputs,
  uint32_t* input_ids,
  uint32_t output_id,
  uint32_t flags)
{
  assert(num_inputs >= 2);
  assert(num_inputs <= 4);

  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(node_type)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_node_id(node_type, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];

  status = xnn_subgraph_check_output_type_dense(node_type, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  if (axis >= output_value->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with the output ID #%" PRIu32
      ": axis (%zu) exceeds the number of dimensions (%zu)",
      xnn_node_type_to_string(node_type), output_id, axis, output_value->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  for (size_t i = 0; i < num_inputs; i++) {
    status = check_input_value(subgraph, axis, input_ids[i], output_id, i+1, node_type);
    if (status != xnn_status_success) {
      return status;
    }
  }

  // The output's axis dimension must equal the sum of the inputs' axis dimensions.
  size_t input_axis_dimensions_sum = 0;
  for (size_t i = 0; i < num_inputs; i++) {
    const struct xnn_value* input_value = &subgraph->values[input_ids[i]];
    input_axis_dimensions_sum += input_value->shape.dim[axis];
  }

  if (output_value->shape.dim[axis] != input_axis_dimensions_sum) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32
      ": mismatch axis dimension %zu, output has %zu, sum of input dimensions is %zu",
      xnn_node_type_to_string(node_type), output_id, axis, output_value->shape.dim[axis], input_axis_dimensions_sum);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_datatype_fp16:
      compute_type = xnn_compute_type_fp16;
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
      // Quantization parameters of every input must match the output.
      // Bug fix: these checks previously ignored the returned status (so a
      // mismatch was logged but the node was still created), and the checks
      // for the third/fourth inputs ran even for non-quantized compute types.
      static const char* const nth_names[4] = { "first", "second", "third", "fourth" };
      for (size_t i = 0; i < num_inputs; i++) {
        status = check_input_compute_type(subgraph, input_ids[i], output_id, nth_names[i], node_type);
        if (status != xnn_status_success) {
          return status;
        }
      }
    }
  #endif  // !defined( XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.concatenate.axis = axis;
  node->type = node_type;
  node->compute_type = compute_type;
  node->num_inputs = num_inputs;
  node->inputs[0] = input_ids[0];
  node->inputs[1] = input_ids[1];
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  // Select the create/setup callbacks matching the input count and record the
  // extra input ids for the 3- and 4-input variants.
  switch (num_inputs) {
    case 2:
      node->create = create_concatenate2_operator;
      node->setup = setup_concatenate2_operator;
      break;
    case 3:
      node->create = create_concatenate3_operator;
      node->setup = setup_concatenate3_operator;
      node->inputs[2] = input_ids[2];
      break;
    case 4:
      node->create = create_concatenate4_operator;
      node->setup = setup_concatenate4_operator;
      node->inputs[2] = input_ids[2];
      node->inputs[3] = input_ids[3];
      break;
    default:
      XNN_UNREACHABLE;
  }

  return xnn_status_success;
}
640 
// Public entry point for a 2-input concatenation; forwards to the shared
// N-input definition routine.
enum xnn_status xnn_define_concatenate2(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[] = { input1_id, input2_id };
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate2, subgraph, axis, XNN_COUNT_OF(inputs), inputs, output_id, flags);
}
653 
// Public entry point for a 3-input concatenation; forwards to the shared
// N-input definition routine.
enum xnn_status xnn_define_concatenate3(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[] = { input1_id, input2_id, input3_id };
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate3, subgraph, axis, XNN_COUNT_OF(inputs), inputs, output_id, flags);
}
667 
// Public entry point for a 4-input concatenation; forwards to the shared
// N-input definition routine.
enum xnn_status xnn_define_concatenate4(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t input4_id,
  uint32_t output_id,
  uint32_t flags)
{
  uint32_t inputs[] = { input1_id, input2_id, input3_id, input4_id };
  return xnn_define_concatenate_n(
    xnn_node_type_concatenate4, subgraph, axis, XNN_COUNT_OF(inputs), inputs, output_id, flags);
}
682