xref: /aosp_15_r20/external/XNNPACK/test/concatenate2.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>
21 
22 template <typename T> class Concatenate2Test : public ::testing::Test {
23 protected:
Concatenate2Test()24   Concatenate2Test()
25   {
26     random_device = std::unique_ptr<std::random_device>(new std::random_device());
27     rng = std::mt19937((*random_device)());
28     shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
29     dim_dist = std::uniform_int_distribution<size_t>(1, 9);
30     f32dist = std::uniform_real_distribution<float>();
31     i8dist =
32       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
33     u8dist =
34       std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
35     scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);
36 
37     input1_dims = RandomShape();
38     axis = RandomAxis(input1_dims);
39     input2_dims = RandomShape(input1_dims, axis);
40     output_dims = input1_dims;
41     output_dims[axis] = input1_dims[axis] + input2_dims[axis];
42 
43     input1 = std::vector<T>(NumElements(input1_dims));
44     input2 = std::vector<T>(NumElements(input2_dims));
45     operator_output = std::vector<T>(NumElements(output_dims));
46     subgraph_output = std::vector<T>(NumElements(output_dims));
47 
48     signed_zero_point = i8dist(rng);
49     unsigned_zero_point = u8dist(rng);
50     scale = scale_dist(rng);
51 
52     batch_size = 1;
53     channels_1 = 1;
54     channels_2 = 1;
55     for (size_t i = 0; i < axis; i++) {
56       batch_size *= output_dims[i];
57     }
58 
59     for (size_t i = axis; i < input1_dims.size(); i++) {
60       channels_1 *= input1_dims[i];
61       channels_2 *= input2_dims[i];
62     }
63     output_stride = channels_1 + channels_2;
64   }
65 
RandomShape()66   std::vector<size_t> RandomShape()
67   {
68     std::vector<size_t> dims(shape_dist(rng));
69     std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
70     return dims;
71   }
72 
RandomShape(const std::vector<size_t> base_dims,size_t axis)73   std::vector<size_t> RandomShape(const std::vector<size_t> base_dims, size_t axis)
74   {
75     auto dims = base_dims;
76     dims[axis] = dim_dist(rng);
77     return dims;
78   }
79 
RandomAxis(const std::vector<size_t> & dims)80   size_t RandomAxis(const std::vector<size_t>& dims)
81   {
82     return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
83   }
84 
NumElements(const std::vector<size_t> & dims)85   size_t NumElements(const std::vector<size_t>& dims)
86   {
87     return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
88   }
89 
90   std::unique_ptr<std::random_device> random_device;
91   std::mt19937 rng;
92   std::uniform_int_distribution<size_t> shape_dist;
93   std::uniform_int_distribution<size_t> dim_dist;
94   std::uniform_real_distribution<float> f32dist;
95   std::uniform_int_distribution<int32_t> i8dist;
96   std::uniform_int_distribution<int32_t> u8dist;
97   std::uniform_real_distribution<float> scale_dist;
98 
99   uint32_t input1_id;
100   uint32_t input2_id;
101   uint32_t output_id;
102 
103   std::vector<size_t> input1_dims;
104   std::vector<size_t> input2_dims;
105   std::vector<size_t> output_dims;
106 
107   size_t axis;
108   size_t batch_size;
109   size_t channels_1;
110   size_t channels_2;
111   size_t output_stride;
112 
113   int32_t signed_zero_point;
114   int32_t unsigned_zero_point;
115   float scale;
116 
117   std::vector<T> input1;
118   std::vector<T> input2;
119   std::vector<T> operator_output;
120   std::vector<T> subgraph_output;
121 };
122 
// Concrete fixture instantiations for each datatype exercised by the tests below.
using Concatenate2TestQS8 = Concatenate2Test<int8_t>;
using Concatenate2TestQU8 = Concatenate2Test<uint8_t>;
using Concatenate2TestF32 = Concatenate2Test<float>;
126 
// Defines a QS8 concatenate2 node in a fresh subgraph and verifies that the
// recorded node has the expected type, compute type, axis, inputs, outputs,
// and flags.
TEST_F(Concatenate2TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  // RAII guard: the subgraph is destroyed even if an assertion aborts the test.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // Both inputs and the output share the same quantization parameters, as
  // required for concatenation.
  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  // Inspect the recorded node directly to confirm the definition took effect.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate2);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 2);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}
173 
// Defines a QU8 concatenate2 node in a fresh subgraph and verifies that the
// recorded node has the expected type, compute type, axis, inputs, outputs,
// and flags.
TEST_F(Concatenate2TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  // RAII guard: the subgraph is destroyed even if an assertion aborts the test.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // Both inputs and the output share the same quantization parameters, as
  // required for concatenation.
  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  // Inspect the recorded node directly to confirm the definition took effect.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate2);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 2);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}
220 
// Defines an F32 concatenate2 node in a fresh subgraph and verifies that the
// recorded node has the expected type, compute type, axis, inputs, outputs,
// and flags.
TEST_F(Concatenate2TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
  // RAII guard: the subgraph is destroyed even if an assertion aborts the test.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));

  // Inspect the recorded node directly to confirm the definition took effect.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate2);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 2);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}
264 
TEST_F(Concatenate2TestQS8,matches_operator_api)265 TEST_F(Concatenate2TestQS8, matches_operator_api)
266 {
267   std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
268   std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
269   std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
270   std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
271 
272   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
273 
274   xnn_operator_t op1 = nullptr;
275   xnn_operator_t op2 = nullptr;
276 
277   // Call operator API.
278   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
279   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
280   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
281   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
282 
283   ASSERT_EQ(
284     xnn_status_success,
285     xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
286   ASSERT_EQ(
287     xnn_status_success,
288     xnn_setup_copy_nc_x8(
289       op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
290 
291   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
292   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
293 
294   // Call subgraph API.
295   xnn_subgraph_t subgraph = nullptr;
296   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
297   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
298 
299   input1_id = XNN_INVALID_NODE_ID;
300   ASSERT_EQ(
301     xnn_status_success,
302     xnn_define_quantized_tensor_value(
303       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
304       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
305   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
306 
307   input2_id = XNN_INVALID_NODE_ID;
308   ASSERT_EQ(
309     xnn_status_success,
310     xnn_define_quantized_tensor_value(
311       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
312       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
313   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
314 
315   output_id = XNN_INVALID_NODE_ID;
316   ASSERT_EQ(
317     xnn_status_success,
318     xnn_define_quantized_tensor_value(
319       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
320       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
321   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
322 
323   ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));
324 
325   xnn_runtime_t runtime = nullptr;
326   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
327   ASSERT_NE(nullptr, runtime);
328   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
329   std::array<xnn_external_value, 3> external = {
330     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
331     xnn_external_value{output_id, subgraph_output.data()}};
332   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
333   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
334 
335   // Check outputs match.
336   ASSERT_EQ(subgraph_output, operator_output);
337 }
338 
TEST_F(Concatenate2TestQU8,matches_operator_api)339 TEST_F(Concatenate2TestQU8, matches_operator_api)
340 {
341   std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
342   std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
343   std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
344   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
345 
346   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
347 
348   xnn_operator_t op1 = nullptr;
349   xnn_operator_t op2 = nullptr;
350 
351   // Call operator API.
352   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
353   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
354   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
355   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
356 
357   ASSERT_EQ(
358     xnn_status_success,
359     xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
360   ASSERT_EQ(
361     xnn_status_success,
362     xnn_setup_copy_nc_x8(
363       op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
364 
365   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
366   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
367 
368   // Call subgraph API.
369   xnn_subgraph_t subgraph = nullptr;
370   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
371   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
372 
373   input1_id = XNN_INVALID_NODE_ID;
374   ASSERT_EQ(
375     xnn_status_success,
376     xnn_define_quantized_tensor_value(
377       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
378       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
379   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
380 
381   input2_id = XNN_INVALID_NODE_ID;
382   ASSERT_EQ(
383     xnn_status_success,
384     xnn_define_quantized_tensor_value(
385       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
386       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
387   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
388 
389   output_id = XNN_INVALID_NODE_ID;
390   ASSERT_EQ(
391     xnn_status_success,
392     xnn_define_quantized_tensor_value(
393       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 2,
394       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
395   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
396 
397   ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));
398 
399   xnn_runtime_t runtime = nullptr;
400   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
401   ASSERT_NE(nullptr, runtime);
402   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
403   std::array<xnn_external_value, 3> external = {
404     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
405     xnn_external_value{output_id, subgraph_output.data()}};
406   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
407   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
408 
409   // Check outputs match.
410   ASSERT_EQ(subgraph_output, operator_output);
411 }
412 
TEST_F(Concatenate2TestF32,matches_operator_api)413 TEST_F(Concatenate2TestF32, matches_operator_api)
414 {
415   std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
416   std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
417   std::fill(operator_output.begin(), operator_output.end(), std::nanf(""));
418   std::fill(subgraph_output.begin(), subgraph_output.end(), std::nanf(""));
419 
420   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
421 
422   xnn_operator_t op1 = nullptr;
423   xnn_operator_t op2 = nullptr;
424 
425   // Call operator API.
426   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
427   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
428   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
429   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
430 
431   ASSERT_EQ(
432     xnn_status_success,
433     xnn_setup_copy_nc_x32(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
434   ASSERT_EQ(
435     xnn_status_success,
436     xnn_setup_copy_nc_x32(
437       op2, batch_size, input2.data(), (float*) operator_output.data() + op1->channels, nullptr /* thread pool */));
438 
439   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
440   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
441 
442   // Call subgraph API.
443   xnn_subgraph_t subgraph = nullptr;
444   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph));
445   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
446 
447   input1_id = XNN_INVALID_NODE_ID;
448   ASSERT_EQ(
449     xnn_status_success, xnn_define_tensor_value(
450                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
451                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
452   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
453 
454   input2_id = XNN_INVALID_NODE_ID;
455   ASSERT_EQ(
456     xnn_status_success, xnn_define_tensor_value(
457                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
458                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
459   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
460 
461   output_id = XNN_INVALID_NODE_ID;
462   ASSERT_EQ(
463     xnn_status_success, xnn_define_tensor_value(
464                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 2,
465                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
466   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
467 
468   ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0));
469 
470   xnn_runtime_t runtime = nullptr;
471   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
472   ASSERT_NE(nullptr, runtime);
473   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
474   std::array<xnn_external_value, 3> external = {
475     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
476     xnn_external_value{output_id, subgraph_output.data()}};
477   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
478   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
479 
480   // Check outputs match.
481   ASSERT_EQ(subgraph_output, operator_output);
482 }
483