xref: /aosp_15_r20/external/XNNPACK/test/even-split3.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <algorithm>
7 #include <array>
8 #include <cstddef>
9 #include <cstdint>
10 #include <limits>
11 #include <memory>
12 #include <numeric>
13 #include <random>
14 
15 #include <xnnpack.h>
16 #include <xnnpack/node-type.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/subgraph.h>
19 
20 #include <gtest/gtest.h>
21 
22 template <typename T> class EvenSplit3Test : public ::testing::Test {
23 protected:
EvenSplit3Test()24   EvenSplit3Test()
25   {
26     random_device = std::unique_ptr<std::random_device>(new std::random_device());
27     rng = std::mt19937((*random_device)());
28     shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
29     dim_dist = std::uniform_int_distribution<size_t>(1, 9);
30     f32dist = std::uniform_real_distribution<float>();
31     i8dist =
32       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
33     u8dist =
34       std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
35     scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);
36 
37     output1_dims = RandomShape();
38     output2_dims = output1_dims;
39     output3_dims = output1_dims;
40     input_dims = output1_dims;
41     axis = RandomAxis(output1_dims);
42     input_dims[axis] = output1_dims[axis] + output2_dims[axis] + output3_dims[axis];
43 
44     input = std::vector<T>(NumElements(input_dims));
45     operator_output1 = std::vector<T>(NumElements(output1_dims));
46     operator_output2 = std::vector<T>(NumElements(output2_dims));
47     operator_output3 = std::vector<T>(NumElements(output3_dims));
48     subgraph_output1 = std::vector<T>(NumElements(output1_dims));
49     subgraph_output2 = std::vector<T>(NumElements(output2_dims));
50     subgraph_output3 = std::vector<T>(NumElements(output3_dims));
51 
52     signed_zero_point = i8dist(rng);
53     unsigned_zero_point = u8dist(rng);
54     scale = scale_dist(rng);
55 
56     batch_size = 1;
57     input_stride = 1;
58     for (size_t i = 0; i < axis; i++) {
59       batch_size *= input_dims[i];
60     }
61 
62     for (size_t i = axis; i < input_dims.size(); i++) {
63       input_stride *= input_dims[i];
64     }
65     channels = input_stride / 3;
66   }
67 
RandomShape()68   std::vector<size_t> RandomShape()
69   {
70     std::vector<size_t> dims(shape_dist(rng));
71     std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
72     return dims;
73   }
74 
RandomAxis(const std::vector<size_t> & dims)75   size_t RandomAxis(const std::vector<size_t>& dims)
76   {
77     return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
78   }
79 
NumElements(const std::vector<size_t> & dims)80   size_t NumElements(const std::vector<size_t>& dims)
81   {
82     return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
83   }
84 
85   std::unique_ptr<std::random_device> random_device;
86   std::mt19937 rng;
87   std::uniform_int_distribution<size_t> shape_dist;
88   std::uniform_int_distribution<size_t> dim_dist;
89   std::uniform_real_distribution<float> f32dist;
90   std::uniform_int_distribution<int32_t> i8dist;
91   std::uniform_int_distribution<int32_t> u8dist;
92   std::uniform_real_distribution<float> scale_dist;
93 
94   uint32_t output1_id;
95   uint32_t output2_id;
96   uint32_t output3_id;
97   uint32_t input_id;
98 
99   std::vector<size_t> output1_dims;
100   std::vector<size_t> output2_dims;
101   std::vector<size_t> output3_dims;
102   std::vector<size_t> input_dims;
103 
104   size_t axis;
105   size_t batch_size;
106   size_t channels;
107   size_t input_stride;
108 
109   int32_t signed_zero_point;
110   int32_t unsigned_zero_point;
111   float scale;
112 
113   std::vector<T> operator_output1;
114   std::vector<T> operator_output2;
115   std::vector<T> operator_output3;
116   std::vector<T> subgraph_output1;
117   std::vector<T> subgraph_output2;
118   std::vector<T> subgraph_output3;
119   std::vector<T> input;
120 };
121 
// Concrete fixture instantiations for the three datatypes exercised below.
using EvenSplit3TestQS8 = EvenSplit3Test<int8_t>;
using EvenSplit3TestQU8 = EvenSplit3Test<uint8_t>;
using EvenSplit3TestF32 = EvenSplit3Test<float>;
125 
// Verifies that defining an even-split3 node over QS8 tensors records the
// expected node type, compute type, axis, input/output ids, and flags.
TEST_F(EvenSplit3TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  // Ensures the subgraph is released even when an assertion fails early.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // Quantized input value (external id 0) and the three quantized outputs
  // (external ids 1-3) all share the same zero point and scale.
  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  // Inspect the recorded node directly.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 3);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->flags, 0);
}
183 
// Verifies that defining an even-split3 node over QU8 tensors records the
// expected node type, compute type, axis, input/output ids, and flags.
TEST_F(EvenSplit3TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  // Ensures the subgraph is released even when an assertion fails early.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // Quantized input value (external id 0) and the three quantized outputs
  // (external ids 1-3) all share the same zero point and scale.
  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  // Inspect the recorded node directly.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 3);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->flags, 0);
}
241 
// Verifies that defining an even-split3 node over FP32 tensors records the
// expected node type, compute type, axis, input/output ids, and flags.
TEST_F(EvenSplit3TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  // Ensures the subgraph is released even when an assertion fails early.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // FP32 input value (external id 0) and the three outputs (external ids 1-3).
  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output1_dims.size(), output1_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output2_dims.size(), output2_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output3_dims.size(), output3_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  // Inspect the recorded node directly.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 3);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->flags, 0);
}
295 
TEST_F(EvenSplit3TestQS8,matches_operator_api)296 TEST_F(EvenSplit3TestQS8, matches_operator_api)
297 {
298   std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
299   std::fill(operator_output1.begin(), operator_output1.end(), INT8_C(0xA5));
300   std::fill(operator_output2.begin(), operator_output2.end(), INT8_C(0xA5));
301   std::fill(operator_output3.begin(), operator_output3.end(), INT8_C(0xA5));
302   std::fill(subgraph_output1.begin(), subgraph_output1.end(), INT8_C(0xA5));
303   std::fill(subgraph_output2.begin(), subgraph_output2.end(), INT8_C(0xA5));
304   std::fill(subgraph_output3.begin(), subgraph_output3.end(), INT8_C(0xA5));
305 
306   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
307 
308   xnn_operator_t op1 = nullptr;
309   xnn_operator_t op2 = nullptr;
310   xnn_operator_t op3 = nullptr;
311 
312   // Call operator API.
313   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
314   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
315   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
316   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
317   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
318   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
319 
320   ASSERT_EQ(
321     xnn_status_success,
322     xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
323   ASSERT_EQ(
324     xnn_status_success,
325     xnn_setup_copy_nc_x8(
326       op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
327   ASSERT_EQ(
328     xnn_status_success, xnn_setup_copy_nc_x8(
329                           op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
330                           nullptr /* thread pool */));
331 
332   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
333   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
334   ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
335 
336   // Call subgraph API.
337   xnn_subgraph_t subgraph = nullptr;
338   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
339   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
340 
341   input_id = XNN_INVALID_NODE_ID;
342   ASSERT_EQ(
343     xnn_status_success,
344     xnn_define_quantized_tensor_value(
345       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
346       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
347   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
348 
349   output1_id = XNN_INVALID_NODE_ID;
350   ASSERT_EQ(
351     xnn_status_success,
352     xnn_define_quantized_tensor_value(
353       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
354       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
355   ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);
356 
357   output2_id = XNN_INVALID_NODE_ID;
358   ASSERT_EQ(
359     xnn_status_success,
360     xnn_define_quantized_tensor_value(
361       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
362       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
363   ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);
364 
365   output3_id = XNN_INVALID_NODE_ID;
366   ASSERT_EQ(
367     xnn_status_success,
368     xnn_define_quantized_tensor_value(
369       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
370       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
371   ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);
372 
373   ASSERT_EQ(
374     xnn_status_success,
375     xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));
376 
377   xnn_runtime_t runtime = nullptr;
378   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
379   ASSERT_NE(nullptr, runtime);
380   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
381   std::array<xnn_external_value, 4> external = {
382     xnn_external_value{input_id, input.data()},
383     xnn_external_value{output1_id, subgraph_output1.data()},
384     xnn_external_value{output2_id, subgraph_output2.data()},
385     xnn_external_value{output3_id, subgraph_output3.data()},
386   };
387   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
388   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
389 
390   ASSERT_EQ(subgraph_output1, operator_output1);
391   ASSERT_EQ(subgraph_output2, operator_output2);
392   ASSERT_EQ(subgraph_output3, operator_output3);
393 }
394 
// Verifies that the QU8 even-split3 subgraph node produces bit-identical
// output to the equivalent operator-API construction: three strided
// xnn_copy_nc_x8 operators, each reading one adjacent third of the input.
TEST_F(EvenSplit3TestQU8, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
  // Sentinel value 0xA5 exposes any output element left unwritten.
  std::fill(operator_output1.begin(), operator_output1.end(), UINT8_C(0xA5));
  std::fill(operator_output2.begin(), operator_output2.end(), UINT8_C(0xA5));
  std::fill(operator_output3.begin(), operator_output3.end(), UINT8_C(0xA5));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), UINT8_C(0xA5));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), UINT8_C(0xA5));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API: each copy reads `channels` bytes per batch at stride
  // `input_stride`, with op2/op3 offset by one/two thirds of the stride.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Subgraph results must match the operator-API results exactly.
  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
}
493 
// Verifies that the FP32 even-split3 subgraph node produces bit-identical
// output to the equivalent operator-API construction: three strided
// xnn_copy_nc_x32 operators, each reading one adjacent third of the input.
TEST_F(EvenSplit3TestF32, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
  // NaN sentinel exposes any output element left unwritten.
  std::fill(operator_output1.begin(), operator_output1.end(), std::nanf(""));
  std::fill(operator_output2.begin(), operator_output2.end(), std::nanf(""));
  std::fill(operator_output3.begin(), operator_output3.end(), std::nanf(""));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), std::nanf(""));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), std::nanf(""));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API: each copy reads `channels` 32-bit elements per batch at
  // stride `input_stride`; op2/op3 are offset by one/two thirds of the stride
  // (pointer arithmetic on uint32_t* advances in 4-byte units).
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, (uint32_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x32(
                          op3, batch_size, (uint32_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output1_dims.size(), output1_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output2_dims.size(), output2_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output3_dims.size(), output3_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Subgraph results must match the operator-API results exactly.
  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
}
588