xref: /aosp_15_r20/external/XNNPACK/test/max-pooling-2d.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <algorithm>  // For std::generate, std::min.
7 #include <array>      // For std::array.
8 #include <cmath>      // For std::lrintf.
9 #include <cstddef>    // For size_t.
10 #include <cstdint>    // For uint32_t.
11 #include <limits>     // For std::numeric_limits.
12 #include <memory>     // For std::unique_ptr.
13 #include <random>     // For std::random_device, std::mt19937, std::uniform_real_distribution.
14 #include <vector>     // For std::vector.
15 
16 #include <xnnpack.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/requantization.h>
19 #include <xnnpack/subgraph.h>
20 
21 #include <gtest/gtest.h>
22 
23 template <class T> class MaxPooling2DTestBase : public ::testing::Test {
24 protected:
MaxPooling2DTestBase()25   MaxPooling2DTestBase()
26   {
27     random_device = std::unique_ptr<std::random_device>(new std::random_device());
28     rng = std::mt19937((*random_device)());
29     input_size_dist = std::uniform_int_distribution<uint32_t>(10, 15);
30     kernel_size_dist = std::uniform_int_distribution<uint32_t>(2, 5);
31     f32dist = std::uniform_real_distribution<float>();
32     scale_dist = std::uniform_real_distribution<float>(1.0f, 5.0f);
33     i32dist = std::uniform_int_distribution<int32_t>(-10000, 10000);
34     dilation_dist = std::uniform_int_distribution<uint32_t>(1, 2);
35     i8dist =
36       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
37     u8dist =
38       std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
39 
40     batch_size = input_size_dist(rng);
41     input_height = input_size_dist(rng);
42     input_width = input_size_dist(rng);
43     channels = input_size_dist(rng);
44     pooling_height = kernel_size_dist(rng);
45     pooling_width = kernel_size_dist(rng);
46     padding_top = std::uniform_int_distribution<uint32_t>(0, pooling_height - 1)(rng);
47     padding_bottom = std::uniform_int_distribution<uint32_t>(0, pooling_height - 1)(rng);
48     padding_left = std::uniform_int_distribution<uint32_t>(0, pooling_width - 1)(rng);
49     padding_right = std::uniform_int_distribution<uint32_t>(0, pooling_width - 1)(rng);
50     dilation_height = dilation_dist(rng);
51     dilation_width = dilation_height;
52     // stride dimension must be <= filter dimension
53     stride_height = std::uniform_int_distribution<uint32_t>(1, pooling_height)(rng);
54     stride_width = std::uniform_int_distribution<uint32_t>(1, pooling_width)(rng);
55     output_min = -std::numeric_limits<float>::infinity();
56     output_max = std::numeric_limits<float>::infinity();
57     output_height = xnn_compute_convolution_output_dimension(
58       padding_top + input_height + padding_bottom, pooling_height, dilation_height, stride_height);
59     output_width = xnn_compute_convolution_output_dimension(
60       padding_left + input_width + padding_right, pooling_width, dilation_width, stride_width);
61 
62     input_dims = {{batch_size, input_height, input_width, channels}};
63     output_dims = {{batch_size, output_height, output_width, channels}};
64 
65     input = std::vector<T>(XNN_EXTRA_BYTES / sizeof(T) + batch_size * input_height * input_width * channels);
66     operator_output =
67       std::vector<T>(XNN_EXTRA_BYTES / sizeof(T) + batch_size * output_height * output_width * channels);
68     subgraph_output =
69       std::vector<T>(XNN_EXTRA_BYTES / sizeof(T) + batch_size * output_height * output_width * channels);
70   }
71 
72   std::unique_ptr<std::random_device> random_device;
73   std::mt19937 rng;
74   std::uniform_int_distribution<uint32_t> input_size_dist;
75   std::uniform_int_distribution<uint32_t> kernel_size_dist;
76   std::uniform_int_distribution<int32_t> i32dist;
77   std::uniform_real_distribution<float> f32dist;
78   std::uniform_real_distribution<float> scale_dist;
79   std::uniform_int_distribution<uint32_t> dilation_dist;
80   std::uniform_int_distribution<int32_t> i8dist;
81   std::uniform_int_distribution<int32_t> u8dist;
82 
83   uint32_t padding_top;
84   uint32_t padding_right;
85   uint32_t padding_bottom;
86   uint32_t padding_left;
87   uint32_t batch_size;
88   uint32_t input_height;
89   uint32_t input_width;
90   uint32_t pooling_height;
91   uint32_t pooling_width;
92   uint32_t stride_height;
93   uint32_t stride_width;
94   uint32_t dilation_height;
95   uint32_t dilation_width;
96   uint32_t channels;
97   float output_min;
98   float output_max;
99   uint32_t output_height;
100   uint32_t output_width;
101 
102   std::array<size_t, 4> input_dims;
103   std::array<size_t, 4> output_dims;
104 
105   std::vector<T> input;
106   std::vector<T> operator_output;
107   std::vector<T> subgraph_output;
108 };
109 
110 using MaxPooling2DTestQS8 = MaxPooling2DTestBase<int8_t>;
111 using MaxPooling2DTestQU8 = MaxPooling2DTestBase<uint8_t>;
112 using MaxPooling2DTestF32 = MaxPooling2DTestBase<float>;
113 
TEST_F(MaxPooling2DTestQS8, define)
{
  // Defining a max-pooling node between two qint8 tensors must record every pooling
  // parameter and the activation range verbatim, and select the QS8 compute type.
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t graph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &graph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> graph_owner(graph, xnn_delete_subgraph);

  // Both tensors use zero point 0 and scale 1.0.
  uint32_t in_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_qint8, /*zero_point=*/0, /*scale=*/1.0f, input_dims.size(), input_dims.data(),
      /*data=*/nullptr, /*external_id=*/0, /*flags=*/0, &in_id));
  ASSERT_NE(in_id, XNN_INVALID_NODE_ID);

  uint32_t out_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_qint8, /*zero_point=*/0, /*scale=*/1.0f, output_dims.size(), output_dims.data(),
      /*data=*/nullptr, /*external_id=*/1, /*flags=*/0, &out_id));
  ASSERT_NE(out_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_max_pooling_2d(
      graph, padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
      stride_width, dilation_height, dilation_width, output_min, output_max, in_id, out_id, /*flags=*/0));

  // Exactly one node, carrying the requested parameters, must have been added.
  ASSERT_EQ(graph->num_nodes, 1);
  const xnn_node& defined = graph->nodes[0];
  ASSERT_EQ(defined.type, xnn_node_type_max_pooling_2d);
  ASSERT_EQ(defined.compute_type, xnn_compute_type_qs8);

  const auto& pooling = defined.params.pooling_2d;
  ASSERT_EQ(pooling.padding_top, padding_top);
  ASSERT_EQ(pooling.padding_right, padding_right);
  ASSERT_EQ(pooling.padding_bottom, padding_bottom);
  ASSERT_EQ(pooling.padding_left, padding_left);
  ASSERT_EQ(pooling.pooling_height, pooling_height);
  ASSERT_EQ(pooling.pooling_width, pooling_width);
  ASSERT_EQ(pooling.stride_height, stride_height);
  ASSERT_EQ(pooling.stride_width, stride_width);
  ASSERT_EQ(pooling.dilation_height, dilation_height);
  ASSERT_EQ(pooling.dilation_width, dilation_width);

  ASSERT_EQ(defined.activation.output_min, output_min);
  ASSERT_EQ(defined.activation.output_max, output_max);
  ASSERT_EQ(defined.num_inputs, 1);
  ASSERT_EQ(defined.inputs[0], in_id);
  ASSERT_EQ(defined.num_outputs, 1);
  ASSERT_EQ(defined.outputs[0], out_id);
  ASSERT_EQ(defined.flags, 0);
}
164 
TEST_F(MaxPooling2DTestQU8, define)
{
  // Defining a max-pooling node between two quint8 tensors must record every pooling
  // parameter and the activation range verbatim, and select the QU8 compute type.
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t graph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &graph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> graph_owner(graph, xnn_delete_subgraph);

  // Both tensors use zero point 0 and scale 1.0.
  uint32_t in_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_quint8, /*zero_point=*/0, /*scale=*/1.0f, input_dims.size(), input_dims.data(),
      /*data=*/nullptr, /*external_id=*/0, /*flags=*/0, &in_id));
  ASSERT_NE(in_id, XNN_INVALID_NODE_ID);

  uint32_t out_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_quint8, /*zero_point=*/0, /*scale=*/1.0f, output_dims.size(), output_dims.data(),
      /*data=*/nullptr, /*external_id=*/1, /*flags=*/0, &out_id));
  ASSERT_NE(out_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_max_pooling_2d(
      graph, padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
      stride_width, dilation_height, dilation_width, output_min, output_max, in_id, out_id, /*flags=*/0));

  // Exactly one node, carrying the requested parameters, must have been added.
  ASSERT_EQ(graph->num_nodes, 1);
  const xnn_node& defined = graph->nodes[0];
  ASSERT_EQ(defined.type, xnn_node_type_max_pooling_2d);
  ASSERT_EQ(defined.compute_type, xnn_compute_type_qu8);

  const auto& pooling = defined.params.pooling_2d;
  ASSERT_EQ(pooling.padding_top, padding_top);
  ASSERT_EQ(pooling.padding_right, padding_right);
  ASSERT_EQ(pooling.padding_bottom, padding_bottom);
  ASSERT_EQ(pooling.padding_left, padding_left);
  ASSERT_EQ(pooling.pooling_height, pooling_height);
  ASSERT_EQ(pooling.pooling_width, pooling_width);
  ASSERT_EQ(pooling.stride_height, stride_height);
  ASSERT_EQ(pooling.stride_width, stride_width);
  ASSERT_EQ(pooling.dilation_height, dilation_height);
  ASSERT_EQ(pooling.dilation_width, dilation_width);

  ASSERT_EQ(defined.activation.output_min, output_min);
  ASSERT_EQ(defined.activation.output_max, output_max);
  ASSERT_EQ(defined.num_inputs, 1);
  ASSERT_EQ(defined.inputs[0], in_id);
  ASSERT_EQ(defined.num_outputs, 1);
  ASSERT_EQ(defined.outputs[0], out_id);
  ASSERT_EQ(defined.flags, 0);
}
216 
TEST_F(MaxPooling2DTestF32, define)
{
  // Defining a max-pooling node between two fp32 tensors must record every pooling
  // parameter and the activation range verbatim, and select the FP32 compute type.
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t graph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &graph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> graph_owner(graph, xnn_delete_subgraph);

  uint32_t in_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      graph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), /*data=*/nullptr, /*external_id=*/0,
      /*flags=*/0, &in_id));
  ASSERT_NE(in_id, XNN_INVALID_NODE_ID);

  uint32_t out_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      graph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), /*data=*/nullptr, /*external_id=*/1,
      /*flags=*/0, &out_id));
  ASSERT_NE(out_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_max_pooling_2d(
      graph, padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
      stride_width, dilation_height, dilation_width, output_min, output_max, in_id, out_id, /*flags=*/0));

  // Exactly one node, carrying the requested parameters, must have been added.
  ASSERT_EQ(graph->num_nodes, 1);
  const xnn_node& defined = graph->nodes[0];
  ASSERT_EQ(defined.type, xnn_node_type_max_pooling_2d);
  ASSERT_EQ(defined.compute_type, xnn_compute_type_fp32);

  const auto& pooling = defined.params.pooling_2d;
  ASSERT_EQ(pooling.padding_top, padding_top);
  ASSERT_EQ(pooling.padding_right, padding_right);
  ASSERT_EQ(pooling.padding_bottom, padding_bottom);
  ASSERT_EQ(pooling.padding_left, padding_left);
  ASSERT_EQ(pooling.pooling_height, pooling_height);
  ASSERT_EQ(pooling.pooling_width, pooling_width);
  ASSERT_EQ(pooling.stride_height, stride_height);
  ASSERT_EQ(pooling.stride_width, stride_width);
  ASSERT_EQ(pooling.dilation_height, dilation_height);
  ASSERT_EQ(pooling.dilation_width, dilation_width);

  ASSERT_EQ(defined.activation.output_min, output_min);
  ASSERT_EQ(defined.activation.output_max, output_max);
  ASSERT_EQ(defined.num_inputs, 1);
  ASSERT_EQ(defined.inputs[0], in_id);
  ASSERT_EQ(defined.num_outputs, 1);
  ASSERT_EQ(defined.outputs[0], out_id);
  ASSERT_EQ(defined.flags, 0);
}
267 
TEST_F(MaxPooling2DTestQS8,matches_operator_api)268 TEST_F(MaxPooling2DTestQS8, matches_operator_api)
269 {
270   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
271   std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
272   std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
273   std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
274   const int8_t input_zero_point = i8dist(rng);
275   const float input_scale = scale_dist(rng);
276   const int8_t output_zero_point = input_zero_point;
277   const float output_scale = input_scale;
278   const int8_t quantized_output_min = xnn_qs8_quantize(output_min, output_scale, output_zero_point);
279   const int8_t quantized_output_max = xnn_qs8_quantize(output_max, output_scale, output_zero_point);
280 
281   // Call operator API.
282   xnn_operator_t op = nullptr;
283   const xnn_status status = xnn_create_max_pooling2d_nhwc_s8(
284     padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
285     stride_width, dilation_height, dilation_width, channels, channels, channels, quantized_output_min,
286     quantized_output_max, /*flags=*/0, &op);
287   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
288 
289   if (status == xnn_status_unsupported_hardware) {
290     GTEST_SKIP();
291   }
292 
293   ASSERT_EQ(xnn_status_success, status);
294   ASSERT_NE(nullptr, op);
295   ASSERT_EQ(
296     xnn_status_success, xnn_setup_max_pooling2d_nhwc_s8(
297                           op, batch_size, input_height, input_width, input.data(), operator_output.data(),
298                           /*threadpool=*/nullptr));
299 
300   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
301 
302   // Call subgraph API.
303   xnn_subgraph_t subgraph = nullptr;
304   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(2, /*flags=*/0, &subgraph));
305   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
306 
307   uint32_t input_id = XNN_INVALID_NODE_ID;
308   ASSERT_EQ(
309     xnn_status_success, xnn_define_quantized_tensor_value(
310                           subgraph, xnn_datatype_qint8, input_zero_point, input_scale, input_dims.size(),
311                           input_dims.data(), nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
312   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
313 
314   uint32_t output_id = XNN_INVALID_NODE_ID;
315   ASSERT_EQ(
316     xnn_status_success, xnn_define_quantized_tensor_value(
317                           subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(),
318                           output_dims.data(), nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
319   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
320   ASSERT_EQ(
321     xnn_status_success,
322     xnn_define_max_pooling_2d(
323       subgraph, padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
324       stride_width, dilation_height, dilation_width, output_min, output_max, input_id, output_id, /*flags=*/0));
325 
326   xnn_runtime_t runtime = nullptr;
327   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
328   ASSERT_NE(nullptr, runtime);
329   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
330   std::array<xnn_external_value, 2> external = {
331     xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
332   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
333   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
334 
335   for (size_t i = 0; i < batch_size * output_height * output_width * channels; i++) {
336     ASSERT_EQ(subgraph_output[i], operator_output[i]);
337   }
338 }
339 
TEST_F(MaxPooling2DTestQU8,matches_operator_api)340 TEST_F(MaxPooling2DTestQU8, matches_operator_api)
341 {
342   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
343   std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
344   std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
345   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
346   const uint8_t input_zero_point = u8dist(rng);
347   const float input_scale = scale_dist(rng);
348   const uint8_t output_zero_point = input_zero_point;
349   const float output_scale = input_scale;
350   const uint8_t quantized_output_min = xnn_qu8_quantize(output_min, output_scale, output_zero_point);
351   const uint8_t quantized_output_max = xnn_qu8_quantize(output_max, output_scale, output_zero_point);
352 
353   // Call operator API.
354   xnn_operator_t op = nullptr;
355   const xnn_status status = xnn_create_max_pooling2d_nhwc_u8(
356     padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
357     stride_width, dilation_height, dilation_width, channels, channels, channels, quantized_output_min,
358     quantized_output_max, /*flags=*/0, &op);
359   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
360 
361   if (status == xnn_status_unsupported_hardware) {
362     GTEST_SKIP();
363   }
364 
365   ASSERT_EQ(xnn_status_success, status);
366   ASSERT_NE(nullptr, op);
367   ASSERT_EQ(
368     xnn_status_success, xnn_setup_max_pooling2d_nhwc_u8(
369                           op, batch_size, input_height, input_width, input.data(), operator_output.data(),
370                           /*threadpool=*/nullptr));
371 
372   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
373 
374   // Call subgraph API.
375   xnn_subgraph_t subgraph = nullptr;
376   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(2, /*flags=*/0, &subgraph));
377   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
378 
379   uint32_t input_id = XNN_INVALID_NODE_ID;
380   ASSERT_EQ(
381     xnn_status_success, xnn_define_quantized_tensor_value(
382                           subgraph, xnn_datatype_quint8, input_zero_point, input_scale, input_dims.size(),
383                           input_dims.data(), nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
384   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
385 
386   uint32_t output_id = XNN_INVALID_NODE_ID;
387   ASSERT_EQ(
388     xnn_status_success, xnn_define_quantized_tensor_value(
389                           subgraph, xnn_datatype_quint8, output_zero_point, output_scale, output_dims.size(),
390                           output_dims.data(), nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
391   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
392   ASSERT_EQ(
393     xnn_status_success,
394     xnn_define_max_pooling_2d(
395       subgraph, padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
396       stride_width, dilation_height, dilation_width, output_min, output_max, input_id, output_id, /*flags=*/0));
397 
398   xnn_runtime_t runtime = nullptr;
399   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
400   ASSERT_NE(nullptr, runtime);
401   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
402   std::array<xnn_external_value, 2> external = {
403     xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
404   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
405   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
406 
407   for (size_t i = 0; i < batch_size * output_height * output_width * channels; i++) {
408     ASSERT_EQ(subgraph_output[i], operator_output[i]);
409   }
410 }
411 
TEST_F(MaxPooling2DTestF32,matches_operator_api)412 TEST_F(MaxPooling2DTestF32, matches_operator_api)
413 {
414   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
415   std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
416   std::fill(operator_output.begin(), operator_output.end(), nanf(""));
417   std::fill(subgraph_output.begin(), subgraph_output.end(), nanf(""));
418 
419   // Call operator API.
420   xnn_operator_t op = nullptr;
421   const xnn_status status = xnn_create_max_pooling2d_nhwc_f32(
422     padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
423     stride_width, dilation_height, dilation_width, channels, channels, channels, output_min, output_max, /*flags=*/0,
424     &op);
425   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
426 
427   if (status == xnn_status_unsupported_hardware) {
428     GTEST_SKIP();
429   }
430 
431   ASSERT_EQ(xnn_status_success, status);
432   ASSERT_NE(nullptr, op);
433   ASSERT_EQ(
434     xnn_status_success, xnn_setup_max_pooling2d_nhwc_f32(
435                           op, batch_size, input_height, input_width, input.data(), operator_output.data(),
436                           /*threadpool=*/nullptr));
437 
438   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
439 
440   // Call subgraph API.
441   xnn_subgraph_t subgraph = nullptr;
442   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(2, /*flags=*/0, &subgraph));
443   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
444 
445   uint32_t input_id = XNN_INVALID_NODE_ID;
446   ASSERT_EQ(
447     xnn_status_success, xnn_define_tensor_value(
448                           subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr,
449                           /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
450   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
451 
452   uint32_t output_id = XNN_INVALID_NODE_ID;
453   ASSERT_EQ(
454     xnn_status_success, xnn_define_tensor_value(
455                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
456                           /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
457   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
458   ASSERT_EQ(
459     xnn_status_success,
460     xnn_define_max_pooling_2d(
461       subgraph, padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, stride_height,
462       stride_width, dilation_height, dilation_width, output_min, output_max, input_id, output_id, /*flags=*/0));
463 
464   xnn_runtime_t runtime = nullptr;
465   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
466   ASSERT_NE(nullptr, runtime);
467   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
468   std::array<xnn_external_value, 2> external = {
469     xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
470   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
471   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
472 
473   for (size_t i = 0; i < batch_size * output_height * output_width * channels; i++) {
474     ASSERT_EQ(subgraph_output[i], operator_output[i]);
475   }
476 }
477