// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>

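// Test fixture for the 3-way even-split node: picks a random shape and split
// axis, sizes the input as the concatenation of three equal outputs along
// that axis, and allocates matching operator-side and subgraph-side buffers.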
template <typename T> class EvenSplit3Test : public ::testing::Test {
protected:
  EvenSplit3Test()
  {
    random_device = std::unique_ptr<std::random_device>(new std::random_device());
    rng = std::mt19937((*random_device)());
    shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
    dim_dist = std::uniform_int_distribution<size_t>(1, 9);
    f32dist = std::uniform_real_distribution<float>();
    i8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    u8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
    scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);

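    // All three outputs share one random shape; the input matches it except
    // along the split axis, where its extent is the sum of the three equal
    // output extents.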
    output1_dims = RandomShape();
    output2_dims = output1_dims;
    output3_dims = output1_dims;
    input_dims = output1_dims;
    axis = RandomAxis(output1_dims);
    input_dims[axis] = output1_dims[axis] + output2_dims[axis] + output3_dims[axis];

    input = std::vector<T>(NumElements(input_dims));
    operator_output1 = std::vector<T>(NumElements(output1_dims));
    operator_output2 = std::vector<T>(NumElements(output2_dims));
    operator_output3 = std::vector<T>(NumElements(output3_dims));
    subgraph_output1 = std::vector<T>(NumElements(output1_dims));
    subgraph_output2 = std::vector<T>(NumElements(output2_dims));
    subgraph_output3 = std::vector<T>(NumElements(output3_dims));

    signed_zero_point = i8dist(rng);
    unsigned_zero_point = u8dist(rng);
    scale = scale_dist(rng);

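    // Flatten the tensor around the split axis: batch_size is the product of
    // the dimensions before the axis, input_stride is the product of the axis
    // dimension and everything after it, and channels is each output's third
    // of that stride.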
    batch_size = 1;
    input_stride = 1;
    for (size_t i = 0; i < axis; i++) {
      batch_size *= input_dims[i];
    }

    for (size_t i = axis; i < input_dims.size(); i++) {
      input_stride *= input_dims[i];
    }
    channels = input_stride / 3;
  }

  std::vector<size_t> RandomShape()
  {
    std::vector<size_t> dims(shape_dist(rng));
    std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
    return dims;
  }

  size_t RandomAxis(const std::vector<size_t>& dims)
  {
    return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
  }

  size_t NumElements(const std::vector<size_t>& dims)
  {
    return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
  }

  std::unique_ptr<std::random_device> random_device;
  std::mt19937 rng;
  std::uniform_int_distribution<size_t> shape_dist;
  std::uniform_int_distribution<size_t> dim_dist;
  std::uniform_real_distribution<float> f32dist;
  std::uniform_int_distribution<int32_t> i8dist;
  std::uniform_int_distribution<int32_t> u8dist;
  std::uniform_real_distribution<float> scale_dist;

  uint32_t output1_id;
  uint32_t output2_id;
  uint32_t output3_id;
  uint32_t input_id;

  std::vector<size_t> output1_dims;
  std::vector<size_t> output2_dims;
  std::vector<size_t> output3_dims;
  std::vector<size_t> input_dims;

  size_t axis;
  size_t batch_size;
  size_t channels;
  size_t input_stride;

  int32_t signed_zero_point;
  int32_t unsigned_zero_point;
  float scale;

  std::vector<T> operator_output1;
  std::vector<T> operator_output2;
  std::vector<T> operator_output3;
  std::vector<T> subgraph_output1;
  std::vector<T> subgraph_output2;
  std::vector<T> subgraph_output3;
  std::vector<T> input;
};

using EvenSplit3TestQS8 = EvenSplit3Test<int8_t>;
using EvenSplit3TestQU8 = EvenSplit3Test<uint8_t>;
using EvenSplit3TestF32 = EvenSplit3Test<float>;

TEST_F(EvenSplit3TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 3);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(EvenSplit3TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 3);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(EvenSplit3TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output1_dims.size(), output1_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output2_dims.size(), output2_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output3_dims.size(), output3_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 3);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(EvenSplit3TestQS8, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
  std::fill(operator_output1.begin(), operator_output1.end(), INT8_C(0xA5));
  std::fill(operator_output2.begin(), operator_output2.end(), INT8_C(0xA5));
  std::fill(operator_output3.begin(), operator_output3.end(), INT8_C(0xA5));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), INT8_C(0xA5));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), INT8_C(0xA5));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), INT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API.
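  // An even 3-way split along the flattened axis reduces to three strided
  // copies: each reads `channels` elements per batch from an input row of
  // `input_stride` elements, at element offsets 0, channels, and 2*channels.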
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
}

TEST_F(EvenSplit3TestQU8, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
  std::fill(operator_output1.begin(), operator_output1.end(), UINT8_C(0xA5));
  std::fill(operator_output2.begin(), operator_output2.end(), UINT8_C(0xA5));
  std::fill(operator_output3.begin(), operator_output3.end(), UINT8_C(0xA5));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), UINT8_C(0xA5));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), UINT8_C(0xA5));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API.
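  // Same decomposition as the QS8 test: three x8 copies at element offsets
  // 0, channels, and 2*channels within each input row.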
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
}

TEST_F(EvenSplit3TestF32, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
  std::fill(operator_output1.begin(), operator_output1.end(), std::nanf(""));
  std::fill(operator_output2.begin(), operator_output2.end(), std::nanf(""));
  std::fill(operator_output3.begin(), operator_output3.end(), std::nanf(""));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), std::nanf(""));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), std::nanf(""));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API.
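  // F32 splits use x32 copies; the offsets are computed in uint32_t units,
  // so `+ op1->channels` advances by whole 4-byte float elements.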
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, (uint32_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x32(
                          op3, batch_size, (uint32_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output1_dims.size(), output1_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output2_dims.size(), output2_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output3_dims.size(), output3_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
}