// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h> // For size_t.
#include <stdint.h> // For uint32_t.

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>

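// This file implements the even-split subgraph nodes (split2/3/4), which
// partition a tensor into 2, 3, or 4 equal pieces along a single axis. Each
// live output is lowered to a strided copy operator.

// Number of "rows" the copy operators iterate over: the product of all input
// dimensions before the split axis. E.g., for an input of shape [2, 3, 8]
// split along axis 2, batch_size = 2 * 3 = 6.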
static size_t calculate_batch_size(const struct xnn_value* input, size_t axis)
{
  size_t batch_size = 1;
  for (size_t i = 0; i < axis; i++) {
    batch_size *= input->shape.dim[i];
  }
  return batch_size;
}

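// Number of elements between consecutive rows of the input: the product of
// all input dimensions from the split axis onward. For the [2, 3, 8] example
// with axis 2, input_stride = 8.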
static size_t calculate_input_stride(const struct xnn_value* input, size_t axis)
{
  size_t input_stride = 1;
  for (size_t i = axis; i < input->shape.num_dims; i++) {
    input_stride *= input->shape.dim[i];
  }
  return input_stride;
}

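// Creates the copy operator backing a single output of an even-split node:
// each output copies `channels` contiguous elements per row out of every
// `input_stride` input elements. An output that was pruned during
// optimization (XNN_INVALID_VALUE_ID) gets no operator object.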
static enum xnn_status create_even_split_operator_helper(
  const uint32_t output_id,
  const struct xnn_node* node,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  struct xnn_operator_data* opdata,
  size_t index)
{
  if (output_id == XNN_INVALID_VALUE_ID) {
    // The node's output value has been optimized away; don't create an operator object for it.
    return xnn_status_success;
  }

  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16: {
      return xnn_create_copy_nc_x16(
        channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#endif // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32: {
      return xnn_create_copy_nc_x32(
        channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_compute_type_qs8:
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_compute_type_qu8:
#endif // !defined(XNN_NO_QU8_OPERATORS)
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    {
      return xnn_create_copy_nc_x8(
        channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
    }
#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

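// Creates the two copy operators for an even split into two parts; each
// output receives channels = input_stride / 2 elements per row. For the
// [2, 3, 8] example split along axis 2, each output has shape [2, 3, 4].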
static enum xnn_status create_even_split2_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 2);
  uint32_t output1_id = node->outputs[0];
  assert(output1_id != XNN_INVALID_VALUE_ID);
  assert(output1_id < num_values);
  if (values[output1_id].type == xnn_value_type_invalid) {
    output1_id = XNN_INVALID_VALUE_ID;
  }
  uint32_t output2_id = node->outputs[1];
  assert(output2_id != XNN_INVALID_VALUE_ID);
  assert(output2_id < num_values);
  if (values[output2_id].type == xnn_value_type_invalid) {
    output2_id = XNN_INVALID_VALUE_ID;
  }

  const size_t axis = node->params.even_split.axis;
  const size_t batch_size = calculate_batch_size(&values[input_id], axis);
  const size_t input_stride = calculate_input_stride(&values[input_id], axis);
  assert(input_stride % 2 == 0);
  const size_t channels = input_stride / 2;
  const size_t output_stride = channels;

  enum xnn_status status;
  status = create_even_split_operator_helper(output1_id, node, channels, input_stride, output_stride, opdata, 0);
  if (status != xnn_status_success) {
    return status;
  }
  status = create_even_split_operator_helper(output2_id, node, channels, input_stride, output_stride, opdata, 1);
  if (status != xnn_status_success) {
    return status;
  }

  opdata->inputs[0] = input_id;
  opdata->outputs[0] = output1_id;
  opdata->outputs[1] = output2_id;
  opdata->batch_size = batch_size;

  return status;
}

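// Like create_even_split2_operator, but splits the input into three parts.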
static enum xnn_status create_even_split3_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 3);
  uint32_t output1_id = node->outputs[0];
  if (values[output1_id].type == xnn_value_type_invalid) {
    output1_id = XNN_INVALID_VALUE_ID;
  }
  uint32_t output2_id = node->outputs[1];
  if (values[output2_id].type == xnn_value_type_invalid) {
    output2_id = XNN_INVALID_VALUE_ID;
  }
  uint32_t output3_id = node->outputs[2];
  if (values[output3_id].type == xnn_value_type_invalid) {
    output3_id = XNN_INVALID_VALUE_ID;
  }

  const size_t axis = node->params.even_split.axis;
  const size_t batch_size = calculate_batch_size(&values[input_id], axis);
  const size_t input_stride = calculate_input_stride(&values[input_id], axis);
  assert(input_stride % 3 == 0);
  const size_t channels = input_stride / 3;
  const size_t output_stride = channels;

  enum xnn_status status;
  status = create_even_split_operator_helper(output1_id, node, channels, input_stride, output_stride, opdata, 0);
  if (status != xnn_status_success) {
    return status;
  }
  status = create_even_split_operator_helper(output2_id, node, channels, input_stride, output_stride, opdata, 1);
  if (status != xnn_status_success) {
    return status;
  }
  status = create_even_split_operator_helper(output3_id, node, channels, input_stride, output_stride, opdata, 2);
  if (status != xnn_status_success) {
    return status;
  }

  opdata->inputs[0] = input_id;
  opdata->outputs[0] = output1_id;
  opdata->outputs[1] = output2_id;
  opdata->outputs[2] = output3_id;
  opdata->batch_size = batch_size;

  return status;
}

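// Like create_even_split2_operator, but splits the input into four parts.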
static enum xnn_status create_even_split4_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 4);
  uint32_t output1_id = node->outputs[0];
  if (values[output1_id].type == xnn_value_type_invalid) {
    output1_id = XNN_INVALID_VALUE_ID;
  }
  uint32_t output2_id = node->outputs[1];
  if (values[output2_id].type == xnn_value_type_invalid) {
    output2_id = XNN_INVALID_VALUE_ID;
  }
  uint32_t output3_id = node->outputs[2];
  if (values[output3_id].type == xnn_value_type_invalid) {
    output3_id = XNN_INVALID_VALUE_ID;
  }
  uint32_t output4_id = node->outputs[3];
  if (values[output4_id].type == xnn_value_type_invalid) {
    output4_id = XNN_INVALID_VALUE_ID;
  }

  const size_t axis = node->params.even_split.axis;
  const size_t batch_size = calculate_batch_size(&values[input_id], axis);
  const size_t input_stride = calculate_input_stride(&values[input_id], axis);
  assert(input_stride % 4 == 0);
  const size_t channels = input_stride / 4;
  const size_t output_stride = channels;

  enum xnn_status status;
  status = create_even_split_operator_helper(output1_id, node, channels, input_stride, output_stride, opdata, 0);
  if (status != xnn_status_success) {
    return status;
  }
  status = create_even_split_operator_helper(output2_id, node, channels, input_stride, output_stride, opdata, 1);
  if (status != xnn_status_success) {
    return status;
  }
  status = create_even_split_operator_helper(output3_id, node, channels, input_stride, output_stride, opdata, 2);
  if (status != xnn_status_success) {
    return status;
  }
  status = create_even_split_operator_helper(output4_id, node, channels, input_stride, output_stride, opdata, 3);
  if (status != xnn_status_success) {
    return status;
  }

  opdata->inputs[0] = input_id;
  opdata->outputs[0] = output1_id;
  opdata->outputs[1] = output2_id;
  opdata->outputs[2] = output3_id;
  opdata->outputs[3] = output4_id;
  opdata->batch_size = batch_size;

  return status;
}

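// Binds the data pointers for a single output of an even-split node. Output
// `index` reads its slice of each input row starting at
// input_data + index * channels; the copy operator's input stride then skips
// over the other outputs' slices on subsequent rows.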
static enum xnn_status setup_even_split_operator_helper(
  const struct xnn_blob* blobs,
  const uint32_t num_blobs,
  const struct xnn_operator_data* opdata,
  size_t index,
  const size_t channels,
  const void* input_data,
  pthreadpool_t threadpool)
{
  const uint32_t output_id = opdata->outputs[index];
  if (output_id == XNN_INVALID_VALUE_ID) {
    // output_id was removed during optimization, so no operator was created for it.
    assert(opdata->operator_objects[index] == NULL);
    return xnn_status_success;
  }

  assert(output_id < num_blobs);
  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  // operator_objects[index] is guaranteed non-NULL past the early return
  // above, whereas operator_objects[0] may be NULL if the first output was
  // optimized away.
  switch (opdata->operator_objects[index]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_copy_nc_x16: {
      return xnn_setup_copy_nc_x16(
        opdata->operator_objects[index], opdata->batch_size, (const uint16_t*) input_data + index * channels,
        output_data, threadpool);
    }
#endif // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_copy_nc_x32: {
      return xnn_setup_copy_nc_x32(
        opdata->operator_objects[index], opdata->batch_size, (const uint32_t*) input_data + index * channels,
        output_data, threadpool);
    }
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8: {
      return xnn_setup_copy_nc_x8(
        opdata->operator_objects[index], opdata->batch_size, (const uint8_t*) input_data + index * channels,
        output_data, threadpool);
    }
#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

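// Binds input/output pointers for both copy operators of an even split into
// two parts.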
static enum xnn_status setup_even_split2_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  // All live copy operators share the same channel count; read it from the
  // first one that was not pruned during optimization.
  size_t channels = 0;
  for (size_t i = 0; i < 2; i++) {
    if (opdata->operator_objects[i] != NULL) {
      channels = opdata->operator_objects[i]->channels;
      break;
    }
  }
  enum xnn_status status = xnn_status_success;

  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 0, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }
  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 1, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }

  return status;
}

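// Like setup_even_split2_operator, but for three outputs.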
static enum xnn_status setup_even_split3_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  // All live copy operators share the same channel count; read it from the
  // first one that was not pruned during optimization.
  size_t channels = 0;
  for (size_t i = 0; i < 3; i++) {
    if (opdata->operator_objects[i] != NULL) {
      channels = opdata->operator_objects[i]->channels;
      break;
    }
  }
  enum xnn_status status = xnn_status_success;

  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 0, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }
  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 1, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }
  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 2, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }

  return status;
}

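// Like setup_even_split2_operator, but for four outputs.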
static enum xnn_status setup_even_split4_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  // All live copy operators share the same channel count; read it from the
  // first one that was not pruned during optimization.
  size_t channels = 0;
  for (size_t i = 0; i < 4; i++) {
    if (opdata->operator_objects[i] != NULL) {
      channels = opdata->operator_objects[i]->channels;
      break;
    }
  }
  enum xnn_status status = xnn_status_success;

  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 0, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }
  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 1, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }
  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 2, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }
  status = setup_even_split_operator_helper(blobs, num_blobs, opdata, 3, channels, input_data, threadpool);
  if (status != xnn_status_success) {
    return status;
  }

  return status;
}

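// Validates one output of an even-split node against the input: the output
// must be a dense value with the same datatype, the same number of
// dimensions, and the same extent in every dimension except the split one.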
static enum xnn_status check_output_value(
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  uint32_t output_id,
  const char* nth,
  enum xnn_node_type node_type)
{
  const struct xnn_value* input_value = &subgraph->values[input_id];
  const struct xnn_value* output_value = &subgraph->values[output_id];
  enum xnn_status status;

  status = xnn_subgraph_check_output_node_id(node_type, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_type_dense(node_type, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  if (input_value->shape.num_dims != output_value->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with %s output ID #%" PRIu32
      ": mismatched number of dimensions, input has %zu, %s output has %zu",
      xnn_node_type_to_string(node_type), nth, output_id, input_value->shape.num_dims,
      nth, output_value->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  for (size_t i = 0; i < input_value->shape.num_dims; i++) {
    if (i != split_dim && input_value->shape.dim[i] != output_value->shape.dim[i]) {
      xnn_log_error(
        "failed to define %s operator with %s output ID #%" PRIu32
        ": mismatched dimension %zu, %s output has %zu, input has %zu",
        xnn_node_type_to_string(node_type), nth, output_id, i, nth, output_value->shape.dim[i],
        input_value->shape.dim[i]);
      return xnn_status_invalid_parameter;
    }
  }

  status = xnn_subgraph_check_datatype_matches(node_type, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  return xnn_status_success;
}

#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
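// For quantized even splits, every output must carry the same quantization
// parameters as the input, since the underlying copy operators do not
// requantize.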
static enum xnn_status check_output_compute_type(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  const char* nth,
  enum xnn_node_type node_type)
{
  const struct xnn_value* input_value = &subgraph->values[input_id];
  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching quantization zero point across the input (%" PRId32 ") and the %s output (%" PRId32 ")",
      xnn_node_type_to_string(node_type), input_id, output_id,
      input_value->quantization.zero_point, nth, output_value->quantization.zero_point);
    return xnn_status_invalid_parameter;
  }
  if (input_value->quantization.scale != output_value->quantization.scale) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching quantization scale across the input (%.7g) and the %s output (%.7g)",
      xnn_node_type_to_string(node_type), input_id, output_id, input_value->quantization.scale,
      nth, output_value->quantization.scale);
    return xnn_status_invalid_parameter;
  }
  return xnn_status_success;
}
#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)

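// Validates an even-split node and appends it to the subgraph. For
// illustration, a sketch of splitting a [N, 6] value into three [N, 2]
// values along dimension 1 (the value IDs are hypothetical and assumed to
// have been defined on `subgraph` beforehand):
//
//   enum xnn_status status = xnn_define_even_split3(
//     subgraph, /*split_dim=*/1, input_id, out0_id, out1_id, out2_id, /*flags=*/0);
//   assert(status == xnn_status_success);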
enum xnn_status xnn_define_even_split_n(
  enum xnn_node_type node_type,
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  size_t num_outputs,
  const uint32_t* output_ids,
  uint32_t flags)
{
  assert(num_outputs > 1);
  assert(num_outputs < 5);

  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(node_type)) != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_input_node_id(node_type, input_id, subgraph->num_values)) != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(node_type, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  status = check_output_value(subgraph, split_dim, input_id, output_ids[0], "first", node_type);
  if (status != xnn_status_success) {
    return status;
  }
  status = check_output_value(subgraph, split_dim, input_id, output_ids[1], "second", node_type);
  if (status != xnn_status_success) {
    return status;
  }
  if (num_outputs > 2) {
    status = check_output_value(subgraph, split_dim, input_id, output_ids[2], "third", node_type);
    if (status != xnn_status_success) {
      return status;
    }
  }
  if (num_outputs > 3) {
    status = check_output_value(subgraph, split_dim, input_id, output_ids[3], "fourth", node_type);
    if (status != xnn_status_success) {
      return status;
    }
  }

  // Check that the split dimension can be evenly split across the outputs.
  if (split_dim >= input_value->shape.num_dims) {
    xnn_log_error(
      "failed to define %s operator with the input ID #%" PRIu32
      ": split dimension (%zu) exceeds the number of dimensions (%zu)",
      xnn_node_type_to_string(node_type), input_id, split_dim, input_value->shape.num_dims);
    return xnn_status_invalid_parameter;
  }

  if (input_value->shape.dim[split_dim] % num_outputs != 0) {
    xnn_log_error(
      "failed to define %s operator with the input ID #%" PRIu32
      ": split dimension %zu has value %zu which cannot be evenly split into %zu parts",
      xnn_node_type_to_string(node_type), input_id, split_dim, input_value->shape.dim[split_dim], num_outputs);
    return xnn_status_invalid_parameter;
  }

  // Check that the output split dimensions sum to the input split dimension.
  size_t output_dimensions_sum = 0;
  for (size_t i = 0; i < num_outputs; i++) {
    const struct xnn_value* output_value = &subgraph->values[output_ids[i]];
    output_dimensions_sum += output_value->shape.dim[split_dim];
  }

  if (output_dimensions_sum != input_value->shape.dim[split_dim]) {
    xnn_log_error(
      "failed to define %s operator with the input ID #%" PRIu32
      ": input split dimension value (%zu) does not match the sum of the output split dimensions (%zu)",
      xnn_node_type_to_string(node_type), input_id, input_value->shape.dim[split_dim], output_dimensions_sum);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (input_value->datatype) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_datatype_fp16:
      compute_type = xnn_compute_type_fp16;
      break;
#endif // !defined(XNN_NO_F16_OPERATORS)
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), input_id, xnn_datatype_to_string(input_value->datatype),
        input_value->datatype);
      return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    status = check_output_compute_type(subgraph, input_id, output_ids[0], "first", node_type);
    if (status != xnn_status_success) {
      return status;
    }
    status = check_output_compute_type(subgraph, input_id, output_ids[1], "second", node_type);
    if (status != xnn_status_success) {
      return status;
    }
    if (num_outputs > 2) {
      status = check_output_compute_type(subgraph, input_id, output_ids[2], "third", node_type);
      if (status != xnn_status_success) {
        return status;
      }
    }
    if (num_outputs > 3) {
      status = check_output_compute_type(subgraph, input_id, output_ids[3], "fourth", node_type);
      if (status != xnn_status_success) {
        return status;
      }
    }
  }
#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.even_split.axis = split_dim;
  node->type = node_type;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = num_outputs;
  node->outputs[0] = output_ids[0];
  node->outputs[1] = output_ids[1];
  switch (num_outputs) {
    case 2:
      node->create = create_even_split2_operator;
      node->setup = setup_even_split2_operator;
      break;
    case 3:
      node->outputs[2] = output_ids[2];
      node->create = create_even_split3_operator;
      node->setup = setup_even_split3_operator;
      break;
    case 4:
      node->outputs[2] = output_ids[2];
      node->outputs[3] = output_ids[3];
      node->create = create_even_split4_operator;
      node->setup = setup_even_split4_operator;
      break;
    default:
      XNN_UNREACHABLE;
  }
  node->flags = flags;

  return xnn_status_success;
}

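// Public entry points: thin wrappers that pack the output IDs into an array
// and dispatch to xnn_define_even_split_n.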
enum xnn_status xnn_define_even_split2(
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t flags)
{
  const uint32_t output_ids[2] = { output1_id, output2_id };
  return xnn_define_even_split_n(
    xnn_node_type_even_split2, subgraph, split_dim, input_id, XNN_COUNT_OF(output_ids), output_ids, flags);
}

enum xnn_status xnn_define_even_split3(
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t output3_id,
  uint32_t flags)
{
  const uint32_t output_ids[3] = { output1_id, output2_id, output3_id };
  return xnn_define_even_split_n(
    xnn_node_type_even_split3, subgraph, split_dim, input_id, XNN_COUNT_OF(output_ids), output_ids, flags);
}

enum xnn_status xnn_define_even_split4(
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t output3_id,
  uint32_t output4_id,
  uint32_t flags)
{
  const uint32_t output_ids[4] = { output1_id, output2_id, output3_id, output4_id };
  return xnn_define_even_split_n(
    xnn_node_type_even_split4, subgraph, split_dim, input_id, XNN_COUNT_OF(output_ids), output_ids, flags);
}