1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
14
15 #include <xnnpack.h>
16 #include <xnnpack/node-type.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/subgraph.h>
19
20 #include <gtest/gtest.h>
21
22 template <typename T> class EvenSplit4Test : public ::testing::Test {
23 protected:
EvenSplit4Test()24 EvenSplit4Test()
25 {
26 random_device = std::unique_ptr<std::random_device>(new std::random_device());
27 rng = std::mt19937((*random_device)());
28 shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
29 dim_dist = std::uniform_int_distribution<size_t>(1, 9);
30 f32dist = std::uniform_real_distribution<float>();
31 i8dist =
32 std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
33 u8dist =
34 std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
35 scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);
36
37 output1_dims = RandomShape();
38 output2_dims = output1_dims;
39 output3_dims = output1_dims;
40 output4_dims = output1_dims;
41 input_dims = output1_dims;
42 axis = RandomAxis(output1_dims);
43 input_dims[axis] = output1_dims[axis] + output2_dims[axis] + output3_dims[axis] + output4_dims[axis];
44
45 input = std::vector<T>(NumElements(input_dims));
46 operator_output1 = std::vector<T>(NumElements(output1_dims));
47 operator_output2 = std::vector<T>(NumElements(output2_dims));
48 operator_output3 = std::vector<T>(NumElements(output3_dims));
49 operator_output4 = std::vector<T>(NumElements(output4_dims));
50 subgraph_output1 = std::vector<T>(NumElements(output1_dims));
51 subgraph_output2 = std::vector<T>(NumElements(output2_dims));
52 subgraph_output3 = std::vector<T>(NumElements(output3_dims));
53 subgraph_output4 = std::vector<T>(NumElements(output4_dims));
54
55 signed_zero_point = i8dist(rng);
56 unsigned_zero_point = u8dist(rng);
57 scale = scale_dist(rng);
58
59 batch_size = 1;
60 input_stride = 1;
61 for (size_t i = 0; i < axis; i++) {
62 batch_size *= input_dims[i];
63 }
64
65 for (size_t i = axis; i < input_dims.size(); i++) {
66 input_stride *= input_dims[i];
67 }
68 channels = input_stride / 4;
69 }
70
RandomShape()71 std::vector<size_t> RandomShape()
72 {
73 std::vector<size_t> dims(shape_dist(rng));
74 std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
75 return dims;
76 }
77
RandomAxis(const std::vector<size_t> & dims)78 size_t RandomAxis(const std::vector<size_t>& dims)
79 {
80 return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
81 }
82
NumElements(const std::vector<size_t> & dims)83 size_t NumElements(const std::vector<size_t>& dims)
84 {
85 return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
86 }
87
88 std::unique_ptr<std::random_device> random_device;
89 std::mt19937 rng;
90 std::uniform_int_distribution<size_t> shape_dist;
91 std::uniform_int_distribution<size_t> dim_dist;
92 std::uniform_real_distribution<float> f32dist;
93 std::uniform_int_distribution<int32_t> i8dist;
94 std::uniform_int_distribution<int32_t> u8dist;
95 std::uniform_real_distribution<float> scale_dist;
96
97 uint32_t output1_id;
98 uint32_t output2_id;
99 uint32_t output3_id;
100 uint32_t output4_id;
101 uint32_t input_id;
102
103 std::vector<size_t> output1_dims;
104 std::vector<size_t> output2_dims;
105 std::vector<size_t> output3_dims;
106 std::vector<size_t> output4_dims;
107 std::vector<size_t> input_dims;
108
109 size_t axis;
110 size_t batch_size;
111 size_t channels;
112 size_t input_stride;
113
114 int32_t signed_zero_point;
115 int32_t unsigned_zero_point;
116 float scale;
117
118 std::vector<T> operator_output1;
119 std::vector<T> operator_output2;
120 std::vector<T> operator_output3;
121 std::vector<T> operator_output4;
122 std::vector<T> subgraph_output1;
123 std::vector<T> subgraph_output2;
124 std::vector<T> subgraph_output3;
125 std::vector<T> subgraph_output4;
126 std::vector<T> input;
127 };
128
// Instantiate the fixture for each compute type exercised below.
using EvenSplit4TestQS8 = EvenSplit4Test<int8_t>;
using EvenSplit4TestQU8 = EvenSplit4Test<uint8_t>;
using EvenSplit4TestF32 = EvenSplit4Test<float>;
132
// Defining an even-split4 node over QS8 tensors records a single node with
// the expected type, compute type, axis, inputs, and outputs.
TEST_F(EvenSplit4TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // External QS8 input tensor (external value id 0).
  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  // Four external QS8 output tensors (external value ids 1-4), all sharing
  // the input's quantization parameters.
  const std::vector<size_t>* out_dims[4] = {&output1_dims, &output2_dims, &output3_dims, &output4_dims};
  uint32_t* out_ids[4] = {&output1_id, &output2_id, &output3_id, &output4_id};
  for (uint32_t i = 0; i < 4; i++) {
    *out_ids[i] = XNN_INVALID_NODE_ID;
    ASSERT_EQ(
      xnn_status_success,
      xnn_define_quantized_tensor_value(
        subgraph, xnn_datatype_qint8, signed_zero_point, scale, out_dims[i]->size(), out_dims[i]->data(), nullptr,
        i + 1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, out_ids[i]));
    ASSERT_NE(*out_ids[i], XNN_INVALID_NODE_ID);
  }

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  // Verify the single recorded node.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 4);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->outputs[3], output4_id);
  ASSERT_EQ(node->flags, 0);
}
199
// Defining an even-split4 node over QU8 tensors records a single node with
// the expected type, compute type, axis, inputs, and outputs.
TEST_F(EvenSplit4TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // External QU8 input tensor (external value id 0).
  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  // Four external QU8 output tensors (external value ids 1-4), all sharing
  // the input's quantization parameters.
  const std::vector<size_t>* out_dims[4] = {&output1_dims, &output2_dims, &output3_dims, &output4_dims};
  uint32_t* out_ids[4] = {&output1_id, &output2_id, &output3_id, &output4_id};
  for (uint32_t i = 0; i < 4; i++) {
    *out_ids[i] = XNN_INVALID_NODE_ID;
    ASSERT_EQ(
      xnn_status_success,
      xnn_define_quantized_tensor_value(
        subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, out_dims[i]->size(), out_dims[i]->data(), nullptr,
        i + 1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, out_ids[i]));
    ASSERT_NE(*out_ids[i], XNN_INVALID_NODE_ID);
  }

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  // Verify the single recorded node.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 4);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->outputs[3], output4_id);
  ASSERT_EQ(node->flags, 0);
}
266
// Defining an even-split4 node over FP32 tensors records a single node with
// the expected type, compute type, axis, inputs, and outputs.
TEST_F(EvenSplit4TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // External FP32 input tensor (external value id 0).
  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  // Four external FP32 output tensors (external value ids 1-4).
  const std::vector<size_t>* out_dims[4] = {&output1_dims, &output2_dims, &output3_dims, &output4_dims};
  uint32_t* out_ids[4] = {&output1_id, &output2_id, &output3_id, &output4_id};
  for (uint32_t i = 0; i < 4; i++) {
    *out_ids[i] = XNN_INVALID_NODE_ID;
    ASSERT_EQ(
      xnn_status_success, xnn_define_tensor_value(
                            subgraph, xnn_datatype_fp32, out_dims[i]->size(), out_dims[i]->data(), nullptr, i + 1,
                            /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, out_ids[i]));
    ASSERT_NE(*out_ids[i], XNN_INVALID_NODE_ID);
  }

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  // Verify the single recorded node.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_even_split4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.even_split.axis, axis);
  ASSERT_EQ(node->num_inputs, 1);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->num_outputs, 4);
  ASSERT_EQ(node->outputs[0], output1_id);
  ASSERT_EQ(node->outputs[1], output2_id);
  ASSERT_EQ(node->outputs[2], output3_id);
  ASSERT_EQ(node->outputs[3], output4_id);
  ASSERT_EQ(node->flags, 0);
}
328
// Checks that the QS8 even-split4 subgraph node produces bit-identical
// results to four copy operators, each copying one quarter-row of the input.
// Fix: the sanity checks after defining outputs 3 and 4 previously re-tested
// output2_id, so a failed definition of output 3 or 4 went undetected.
TEST_F(EvenSplit4TestQS8, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
  // Poison all output buffers so untouched elements are detectable.
  std::fill(operator_output1.begin(), operator_output1.end(), INT8_C(0xA5));
  std::fill(operator_output2.begin(), operator_output2.end(), INT8_C(0xA5));
  std::fill(operator_output3.begin(), operator_output3.end(), INT8_C(0xA5));
  std::fill(operator_output4.begin(), operator_output4.end(), INT8_C(0xA5));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), INT8_C(0xA5));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), INT8_C(0xA5));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), INT8_C(0xA5));
  std::fill(subgraph_output4.begin(), subgraph_output4.end(), INT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API: each copy operator reads `channels` bytes per batch
  // row (stride `input_stride`) starting at consecutive quarter offsets.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op4, batch_size, (uint8_t*) input.data() + op1->channels * 3, operator_output4.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);  // was output2_id: copy-paste bug

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output4_dims.size(), output4_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);  // was output2_id: copy-paste bug

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
    xnn_external_value{output4_id, subgraph_output4.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Both APIs must produce bit-identical results.
  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
  ASSERT_EQ(subgraph_output4, operator_output4);
}
447
// Checks that the QU8 even-split4 subgraph node produces bit-identical
// results to four copy operators, each copying one quarter-row of the input.
TEST_F(EvenSplit4TestQU8, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
  // Poison all output buffers so untouched elements are detectable.
  std::fill(operator_output1.begin(), operator_output1.end(), UINT8_C(0xA5));
  std::fill(operator_output2.begin(), operator_output2.end(), UINT8_C(0xA5));
  std::fill(operator_output3.begin(), operator_output3.end(), UINT8_C(0xA5));
  std::fill(operator_output4.begin(), operator_output4.end(), UINT8_C(0xA5));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), UINT8_C(0xA5));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), UINT8_C(0xA5));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), UINT8_C(0xA5));
  std::fill(subgraph_output4.begin(), subgraph_output4.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API: each copy operator reads `channels` bytes per batch
  // row (stride `input_stride`); ops 2-4 start at consecutive quarter offsets.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels, input_stride, channels, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, (uint8_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op3, batch_size, (uint8_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x8(
                          op4, batch_size, (uint8_t*) input.data() + op1->channels * 3, operator_output4.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API: one external input (id 0) and four external outputs
  // (ids 1-4), all QU8 with the same quantization parameters.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input_dims.size(), input_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output1_dims.size(), output1_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output2_dims.size(), output2_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output3_dims.size(), output3_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output4_dims.size(), output4_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
    xnn_external_value{output4_id, subgraph_output4.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Both APIs must produce bit-identical results.
  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
  ASSERT_EQ(subgraph_output4, operator_output4);
}
566
// Checks that the FP32 even-split4 subgraph node produces bit-identical
// results to four copy operators, each copying one quarter-row of the input.
TEST_F(EvenSplit4TestF32, matches_operator_api)
{
  std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
  // Poison all output buffers with NaN so untouched elements are detectable.
  std::fill(operator_output1.begin(), operator_output1.end(), std::nanf(""));
  std::fill(operator_output2.begin(), operator_output2.end(), std::nanf(""));
  std::fill(operator_output3.begin(), operator_output3.end(), std::nanf(""));
  std::fill(operator_output4.begin(), operator_output4.end(), std::nanf(""));
  std::fill(subgraph_output1.begin(), subgraph_output1.end(), std::nanf(""));
  std::fill(subgraph_output2.begin(), subgraph_output2.end(), std::nanf(""));
  std::fill(subgraph_output3.begin(), subgraph_output3.end(), std::nanf(""));
  std::fill(subgraph_output4.begin(), subgraph_output4.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API: each x32 copy operator reads `channels` elements per
  // batch row (stride `input_stride`); ops 2-4 start at quarter offsets.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels, input_stride, channels, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input.data(), operator_output1.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, (uint32_t*) input.data() + op1->channels, operator_output2.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x32(
                          op3, batch_size, (uint32_t*) input.data() + op1->channels * 2, operator_output3.data(),
                          nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x32(
                          op4, batch_size, (uint32_t*) input.data() + op1->channels * 3, operator_output4.data(),
                          nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API: one external FP32 input (id 0) and four external
  // outputs (ids 1-4).
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  output1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output1_dims.size(), output1_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output1_id));
  ASSERT_NE(output1_id, XNN_INVALID_NODE_ID);

  output2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output2_dims.size(), output2_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output2_id));
  ASSERT_NE(output2_id, XNN_INVALID_NODE_ID);

  output3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output3_dims.size(), output3_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output3_id));
  ASSERT_NE(output3_id, XNN_INVALID_NODE_ID);

  output4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output4_dims.size(), output4_dims.data(), nullptr, 4,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output4_id));
  ASSERT_NE(output4_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output1_id, subgraph_output1.data()},
    xnn_external_value{output2_id, subgraph_output2.data()},
    xnn_external_value{output3_id, subgraph_output3.data()},
    xnn_external_value{output4_id, subgraph_output4.data()},
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Both APIs must produce bit-identical results (NaN canaries would differ
  // under operator== only if an element were left unwritten by exactly one
  // side, since NaN != NaN).
  ASSERT_EQ(subgraph_output1, operator_output1);
  ASSERT_EQ(subgraph_output2, operator_output2);
  ASSERT_EQ(subgraph_output3, operator_output3);
  ASSERT_EQ(subgraph_output4, operator_output4);
}
680