// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>

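// Test fixture for the even-split4 subgraph node: an input tensor is split
// into four equal parts along a randomly chosen axis, and the subgraph API
// results are checked against the operator API.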
template <typename T> class EvenSplit4Test : public ::testing::Test {
protected:
  EvenSplit4Test()
  {
    random_device = std::unique_ptr<std::random_device>(new std::random_device());
    rng = std::mt19937((*random_device)());
    shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
    dim_dist = std::uniform_int_distribution<size_t>(1, 9);
    f32dist = std::uniform_real_distribution<float>();
    i8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    u8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
    scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);

    output1_dims = RandomShape();
    output2_dims = output1_dims;
    output3_dims = output1_dims;
    output4_dims = output1_dims;
    input_dims = output1_dims;
    axis = RandomAxis(output1_dims);
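    // The input matches the outputs in every dimension except the split axis,
    // whose extent is the sum of the four (equal) output extents.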
    input_dims[axis] = output1_dims[axis] + output2_dims[axis] + output3_dims[axis] + output4_dims[axis];

    input = std::vector<T>(NumElements(input_dims));
    operator_output1 = std::vector<T>(NumElements(output1_dims));
    operator_output2 = std::vector<T>(NumElements(output2_dims));
    operator_output3 = std::vector<T>(NumElements(output3_dims));
    operator_output4 = std::vector<T>(NumElements(output4_dims));
    subgraph_output1 = std::vector<T>(NumElements(output1_dims));
    subgraph_output2 = std::vector<T>(NumElements(output2_dims));
    subgraph_output3 = std::vector<T>(NumElements(output3_dims));
    subgraph_output4 = std::vector<T>(NumElements(output4_dims));

    signed_zero_point = i8dist(rng);
    unsigned_zero_point = u8dist(rng);
    scale = scale_dist(rng);

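    // Flatten the tensor around the split axis: batch_size is the product of
    // the dimensions before the axis, input_stride is the number of elements
    // per batch from the axis onward, and channels is the quarter of that
    // stride that each of the four outputs receives.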
    batch_size = 1;
    input_stride = 1;
    for (size_t i = 0; i < axis; i++) {
      batch_size *= input_dims[i];
    }

    for (size_t i = axis; i < input_dims.size(); i++) {
      input_stride *= input_dims[i];
    }
    channels = input_stride / 4;
  }

  std::vector<size_t> RandomShape()
  {
    std::vector<size_t> dims(shape_dist(rng));
    std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
    return dims;
  }

  size_t RandomAxis(const std::vector<size_t>& dims)
  {
    return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
  }

  size_t NumElements(const std::vector<size_t>& dims)
  {
    return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
  }

  std::unique_ptr<std::random_device> random_device;
  std::mt19937 rng;
  std::uniform_int_distribution<size_t> shape_dist;
  std::uniform_int_distribution<size_t> dim_dist;
  std::uniform_real_distribution<float> f32dist;
  std::uniform_int_distribution<int32_t> i8dist;
  std::uniform_int_distribution<int32_t> u8dist;
  std::uniform_real_distribution<float> scale_dist;

  uint32_t output1_id;
  uint32_t output2_id;
  uint32_t output3_id;
  uint32_t output4_id;
  uint32_t input_id;

  std::vector<size_t> output1_dims;
  std::vector<size_t> output2_dims;
  std::vector<size_t> output3_dims;
  std::vector<size_t> output4_dims;
  std::vector<size_t> input_dims;

  size_t axis;
  size_t batch_size;
  size_t channels;
  size_t input_stride;

  int32_t signed_zero_point;
  int32_t unsigned_zero_point;
  float scale;

  std::vector<T> operator_output1;
  std::vector<T> operator_output2;
  std::vector<T> operator_output3;
  std::vector<T> operator_output4;
  std::vector<T> subgraph_output1;
  std::vector<T> subgraph_output2;
  std::vector<T> subgraph_output3;
  std::vector<T> subgraph_output4;
  std::vector<T> input;
};

using EvenSplit4TestQS8 = EvenSplit4Test<int8_t>;
using EvenSplit4TestQU8 = EvenSplit4Test<uint8_t>;
using EvenSplit4TestF32 = EvenSplit4Test<float>;

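// The define tests verify that xnn_define_even_split4 records a single node
// with the expected type, compute type, split axis, input, outputs, and flags.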
TEST_F(EvenSplit4TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output4_dims.size(), output4_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 4);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->outputs[3], output4_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(EvenSplit4TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output4_dims.size(), output4_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 4);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->outputs[3], output4_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(EvenSplit4TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output1_dims.size(), output1_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output2_dims.size(), output2_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output3_dims.size(), output3_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output4_dims.size(), output4_dims.data(), nullptr, 4,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 4);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->outputs[3], output4_id);
  ASSERT_EQ(node->flags, 0);
}

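// The matches_operator_api tests emulate even-split4 with four copy_nc
// operators, each copying one quarter of every input row, and expect the
// subgraph runtime to produce identical outputs.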
TEST_F(EvenSplit4TestQS8, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
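  // Prime all outputs with a sentinel value so that any element the split
  // fails to write remains visible in the final comparison.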
  std::fill(operator_output1.begin(), operator_output1.end(), INT8_C(0xA5));
  std::fill(operator_output2.begin(), operator_output2.end(), INT8_C(0xA5));
  std::fill(operator_output3.begin(), operator_output3.end(), INT8_C(0xA5));
  std::fill(operator_output4.begin(), operator_output4.end(), INT8_C(0xA5));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), INT8_C(0xA5));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), INT8_C(0xA5));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), INT8_C(0xA5));
  std::fill(subgraph_output4.begin(), subgraph_output4.end(), INT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

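  // The k-th split (k = 0..3) reads its rows starting k * channels elements
  // into the input; input_stride advances each copy operator one full input
  // row per batch element.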
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op4, batch_size, (uint8_t*) input.data() + op1->channels * 3, operator_output4.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output4_dims.size(), output4_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
    xnn_external_value{output4_id, subgraph_output4.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
  ASSERT_EQ(subgraph_output4, operator_output4);
}

TEST_F(EvenSplit4TestQU8, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
  std::fill(operator_output1.begin(), operator_output1.end(), UINT8_C(0xA5));
  std::fill(operator_output2.begin(), operator_output2.end(), UINT8_C(0xA5));
  std::fill(operator_output3.begin(), operator_output3.end(), UINT8_C(0xA5));
  std::fill(operator_output4.begin(), operator_output4.end(), UINT8_C(0xA5));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), UINT8_C(0xA5));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), UINT8_C(0xA5));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), UINT8_C(0xA5));
  std::fill(subgraph_output4.begin(), subgraph_output4.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op4, batch_size, (uint8_t*) input.data() + op1->channels * 3, operator_output4.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output4_dims.size(), output4_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
    xnn_external_value{output4_id, subgraph_output4.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
  ASSERT_EQ(subgraph_output4, operator_output4);
}

TEST_F(EvenSplit4TestF32, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
  std::fill(operator_output1.begin(), operator_output1.end(), std::nanf(""));
  std::fill(operator_output2.begin(), operator_output2.end(), std::nanf(""));
  std::fill(operator_output3.begin(), operator_output3.end(), std::nanf(""));
  std::fill(operator_output4.begin(), operator_output4.end(), std::nanf(""));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), std::nanf(""));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), std::nanf(""));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), std::nanf(""));
  std::fill(subgraph_output4.begin(), subgraph_output4.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

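  // The x32 copy operators index in 32-bit elements, so the quarter-row
  // offsets below use uint32_t* arithmetic over the float input.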
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, (uint32_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x32(
                          op3, batch_size, (uint32_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x32(
                          op4, batch_size, (uint32_t*) input.data() + op1->channels * 3, operator_output4.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output1_dims.size(), output1_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output2_dims.size(), output2_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output3_dims.size(), output3_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output4_dims.size(), output4_dims.data(), nullptr, 4,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
    xnn_external_value{output4_id, subgraph_output4.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
  ASSERT_EQ(subgraph_output4, operator_output4);
}