// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>

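// Test fixture for the 2-input concatenation subgraph tests. It draws a random
// shape and a random concatenation axis, derives a second input shape that
// differs from the first only along that axis, and precomputes the quantization
// parameters and the batch/channel decomposition used by the operator-API
// reference below.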
template <typename T> class Concatenate2Test : public ::testing::Test {
 protected:
  Concatenate2Test()
  {
    random_device = std::unique_ptr<std::random_device>(new std::random_device());
    rng = std::mt19937((*random_device)());
    shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
    dim_dist = std::uniform_int_distribution<size_t>(1, 9);
    f32dist = std::uniform_real_distribution<float>();
    i8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    u8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
    scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);

    input1_dims = RandomShape();
    axis = RandomAxis(input1_dims);
    input2_dims = RandomShape(input1_dims, axis);
    output_dims = input1_dims;
    output_dims[axis] = input1_dims[axis] + input2_dims[axis];

    input1 = std::vector<T>(NumElements(input1_dims));
    input2 = std::vector<T>(NumElements(input2_dims));
    operator_output = std::vector<T>(NumElements(output_dims));
    subgraph_output = std::vector<T>(NumElements(output_dims));

    signed_zero_point = i8dist(rng);
    unsigned_zero_point = u8dist(rng);
    scale = scale_dist(rng);

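    // Flatten the tensors for the copy-operator reference: dimensions before
    // the concatenation axis form the batch, dimensions from the axis onward
    // form each input's channel count, and one output row therefore spans
    // channels_1 + channels_2 elements.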
    batch_size = 1;
    channels_1 = 1;
    channels_2 = 1;
    for (size_t i = 0; i < axis; i++) {
      batch_size *= output_dims[i];
    }

    for (size_t i = axis; i < input1_dims.size(); i++) {
      channels_1 *= input1_dims[i];
      channels_2 *= input2_dims[i];
    }
    output_stride = channels_1 + channels_2;
  }

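  // Generates a random shape with up to XNN_MAX_TENSOR_DIMS dimensions,
  // each of extent 1 to 9.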
  std::vector<size_t> RandomShape()
  {
    std::vector<size_t> dims(shape_dist(rng));
    std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
    return dims;
  }

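  // Generates a shape that matches base_dims everywhere except along the given
  // axis, whose extent is re-randomized.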
  std::vector<size_t> RandomShape(const std::vector<size_t> base_dims, size_t axis)
  {
    auto dims = base_dims;
    dims[axis] = dim_dist(rng);
    return dims;
  }

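  // Picks a random valid concatenation axis for the given shape.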
  size_t RandomAxis(const std::vector<size_t>& dims)
  {
    return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
  }

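  // Returns the total number of elements in a tensor with the given dimensions.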
  size_t NumElements(const std::vector<size_t>& dims)
  {
    return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
  }

  std::unique_ptr<std::random_device> random_device;
  std::mt19937 rng;
  std::uniform_int_distribution<size_t> shape_dist;
  std::uniform_int_distribution<size_t> dim_dist;
  std::uniform_real_distribution<float> f32dist;
  std::uniform_int_distribution<int32_t> i8dist;
  std::uniform_int_distribution<int32_t> u8dist;
  std::uniform_real_distribution<float> scale_dist;

  uint32_t input1_id;
  uint32_t input2_id;
  uint32_t output_id;

  std::vector<size_t> input1_dims;
  std::vector<size_t> input2_dims;
  std::vector<size_t> output_dims;

  size_t axis;
  size_t batch_size;
  size_t channels_1;
  size_t channels_2;
  size_t output_stride;

  int32_t signed_zero_point;
  int32_t unsigned_zero_point;
  float scale;

  std::vector<T> input1;
  std::vector<T> input2;
  std::vector<T> operator_output;
  std::vector<T> subgraph_output;
};

using Concatenate2TestQS8 = Concatenate2Test<int8_t>;
using Concatenate2TestQU8 = Concatenate2Test<uint8_t>;
using Concatenate2TestF32 = Concatenate2Test<float>;

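// The "define" tests only build the subgraph and check that the recorded
// concatenate2 node carries the expected type, compute type, axis, inputs,
// outputs, and flags.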
TEST_F(Concatenate2TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate2);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 2);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate2TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate2);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 2);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate2TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate2);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 2);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

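// The "matches_operator_api" tests build a reference result with the operator
// API: concatenation along the (flattened) axis is emulated by two copy
// operators writing into a shared output buffer with stride
// channels_1 + channels_2, and the subgraph runtime output must match it
// exactly.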
TEST_F(Concatenate2TestQS8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
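  // The second copy starts channels_1 elements into the output buffer, so with
  // an output stride of channels_1 + channels_2 its rows land directly after
  // the rows written by the first copy.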
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 3> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

TEST_F(Concatenate2TestQU8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 3> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

TEST_F(Concatenate2TestF32, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), std::nanf(""));
  std::fill(subgraph_output.begin(), subgraph_output.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, input2.data(), (float*) operator_output.data() + op1->channels, nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 3> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}