xref: /aosp_15_r20/external/XNNPACK/test/unpooling-2d.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>  // For std::generate, std::min.
7*4bdc9457SAndroid Build Coastguard Worker #include <array>      // For std::array.
8*4bdc9457SAndroid Build Coastguard Worker #include <cmath>      // For std::lrintf.
9*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>    // For size_t.
10*4bdc9457SAndroid Build Coastguard Worker #include <cstdint>    // For uint32_t.
11*4bdc9457SAndroid Build Coastguard Worker #include <limits>     // For std::numeric_limits.
12*4bdc9457SAndroid Build Coastguard Worker #include <memory>     // For std::unique_ptr.
13*4bdc9457SAndroid Build Coastguard Worker #include <random>     // For std::random_device, std::mt19937, std::uniform_real_distribution.
14*4bdc9457SAndroid Build Coastguard Worker #include <vector>     // For std::vector.
15*4bdc9457SAndroid Build Coastguard Worker 
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
20*4bdc9457SAndroid Build Coastguard Worker 
21*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker template <class T, class BiasType = T> class Unpooling2DTestBase : public ::testing::Test {
24*4bdc9457SAndroid Build Coastguard Worker protected:
Unpooling2DTestBase()25*4bdc9457SAndroid Build Coastguard Worker   Unpooling2DTestBase()
26*4bdc9457SAndroid Build Coastguard Worker   {
27*4bdc9457SAndroid Build Coastguard Worker     random_device = std::unique_ptr<std::random_device>(new std::random_device());
28*4bdc9457SAndroid Build Coastguard Worker     rng = std::mt19937((*random_device)());
29*4bdc9457SAndroid Build Coastguard Worker     input_size_dist = std::uniform_int_distribution<uint32_t>(10, 15);
30*4bdc9457SAndroid Build Coastguard Worker     kernel_size_dist = std::uniform_int_distribution<uint32_t>(1, 5);
31*4bdc9457SAndroid Build Coastguard Worker     stride_dist = std::uniform_int_distribution<uint32_t>(1, 3);
32*4bdc9457SAndroid Build Coastguard Worker     f32dist = std::uniform_real_distribution<float>(0.1f, 1.0f);
33*4bdc9457SAndroid Build Coastguard Worker     scale_dist = std::uniform_real_distribution<float>(1.0f, 5.0f);
34*4bdc9457SAndroid Build Coastguard Worker     i32dist = std::uniform_int_distribution<int32_t>(-10000, 10000);
35*4bdc9457SAndroid Build Coastguard Worker     u32dist = std::uniform_int_distribution<uint32_t>();
36*4bdc9457SAndroid Build Coastguard Worker 
37*4bdc9457SAndroid Build Coastguard Worker     batch_size = input_size_dist(rng);
38*4bdc9457SAndroid Build Coastguard Worker     input_height = input_size_dist(rng);
39*4bdc9457SAndroid Build Coastguard Worker     input_width = input_size_dist(rng);
40*4bdc9457SAndroid Build Coastguard Worker     pooling_height = 2;
41*4bdc9457SAndroid Build Coastguard Worker     pooling_width = 2;
42*4bdc9457SAndroid Build Coastguard Worker     channels = input_size_dist(rng);
43*4bdc9457SAndroid Build Coastguard Worker     output_height = xnn_compute_unpooling_output_dimension(input_height, padding_top + padding_bottom, pooling_height);
44*4bdc9457SAndroid Build Coastguard Worker     output_width = xnn_compute_unpooling_output_dimension(input_width, padding_left + padding_right, pooling_width);
45*4bdc9457SAndroid Build Coastguard Worker 
46*4bdc9457SAndroid Build Coastguard Worker     index_dist = std::uniform_int_distribution<uint32_t>(0, pooling_height * pooling_width - 1);
47*4bdc9457SAndroid Build Coastguard Worker 
48*4bdc9457SAndroid Build Coastguard Worker     input_value_dims = {{batch_size, input_height, input_width, channels}};
49*4bdc9457SAndroid Build Coastguard Worker     input_index_dims = {{batch_size, input_height, input_width, channels}};
50*4bdc9457SAndroid Build Coastguard Worker     output_dims = {{batch_size, output_height, output_width, channels}};
51*4bdc9457SAndroid Build Coastguard Worker 
52*4bdc9457SAndroid Build Coastguard Worker     input = std::vector<T>(XNN_EXTRA_BYTES / sizeof(T) + batch_size * input_height * input_width * channels);
53*4bdc9457SAndroid Build Coastguard Worker     input_index = std::vector<T>(batch_size * input_height * input_width * channels);
54*4bdc9457SAndroid Build Coastguard Worker     operator_output = std::vector<T>(batch_size * output_height * output_width * channels);
55*4bdc9457SAndroid Build Coastguard Worker     subgraph_output = std::vector<T>(batch_size * output_height * output_width * channels);
56*4bdc9457SAndroid Build Coastguard Worker   }
57*4bdc9457SAndroid Build Coastguard Worker 
58*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<std::random_device> random_device;
59*4bdc9457SAndroid Build Coastguard Worker   std::mt19937 rng;
60*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<uint32_t> input_size_dist;
61*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<uint32_t> kernel_size_dist;
62*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<uint32_t> stride_dist;
63*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<int32_t> i32dist;
64*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<uint32_t> u32dist;
65*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<uint32_t> index_dist;
66*4bdc9457SAndroid Build Coastguard Worker   std::uniform_real_distribution<float> f32dist;
67*4bdc9457SAndroid Build Coastguard Worker   std::uniform_real_distribution<float> scale_dist;
68*4bdc9457SAndroid Build Coastguard Worker 
69*4bdc9457SAndroid Build Coastguard Worker   const uint32_t padding_top = 0;
70*4bdc9457SAndroid Build Coastguard Worker   const uint32_t padding_right = 0;
71*4bdc9457SAndroid Build Coastguard Worker   const uint32_t padding_bottom = 0;
72*4bdc9457SAndroid Build Coastguard Worker   const uint32_t padding_left = 0;
73*4bdc9457SAndroid Build Coastguard Worker   uint32_t batch_size;
74*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_height;
75*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_width;
76*4bdc9457SAndroid Build Coastguard Worker   uint32_t kernel_height;
77*4bdc9457SAndroid Build Coastguard Worker   uint32_t kernel_width;
78*4bdc9457SAndroid Build Coastguard Worker   uint32_t pooling_height;
79*4bdc9457SAndroid Build Coastguard Worker   uint32_t pooling_width;
80*4bdc9457SAndroid Build Coastguard Worker   uint32_t channels;
81*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_height;
82*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_width;
83*4bdc9457SAndroid Build Coastguard Worker 
84*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> input_value_dims;
85*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> input_index_dims;
86*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> output_dims;
87*4bdc9457SAndroid Build Coastguard Worker 
88*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> input;
89*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> input_index;
90*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> operator_output;
91*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> subgraph_output;
92*4bdc9457SAndroid Build Coastguard Worker };
93*4bdc9457SAndroid Build Coastguard Worker 
94*4bdc9457SAndroid Build Coastguard Worker using Unpooling2DTestX32 = Unpooling2DTestBase<uint32_t>;
95*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Unpooling2DTestX32,define)96*4bdc9457SAndroid Build Coastguard Worker TEST_F(Unpooling2DTestX32, define)
97*4bdc9457SAndroid Build Coastguard Worker {
98*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
99*4bdc9457SAndroid Build Coastguard Worker 
100*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
101*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(2, /*flags=*/0, &subgraph));
102*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
103*4bdc9457SAndroid Build Coastguard Worker 
104*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_value_id = XNN_INVALID_NODE_ID;
105*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
106*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
107*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input_value_dims.size(), input_value_dims.data(), nullptr,
108*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_value_id));
109*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_value_id, XNN_INVALID_NODE_ID);
110*4bdc9457SAndroid Build Coastguard Worker 
111*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_index_id = XNN_INVALID_NODE_ID;
112*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
113*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
114*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input_index_dims.size(), input_index_dims.data(),
115*4bdc9457SAndroid Build Coastguard Worker                           input_index.data(), XNN_INVALID_VALUE_ID, /*flags=*/0, &input_index_id));
116*4bdc9457SAndroid Build Coastguard Worker 
117*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
118*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
119*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
120*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
121*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
122*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
125*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_unpooling_2d(
126*4bdc9457SAndroid Build Coastguard Worker                           subgraph, padding_top, padding_right, padding_bottom, padding_left, pooling_height,
127*4bdc9457SAndroid Build Coastguard Worker                           pooling_width, input_value_id, input_index_id, output_id,
128*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/0));
129*4bdc9457SAndroid Build Coastguard Worker 
130*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
131*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
132*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_unpooling_2d);
133*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
134*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.pooling_2d.padding_top, padding_top);
135*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.pooling_2d.padding_right, padding_right);
136*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.pooling_2d.padding_bottom, padding_bottom);
137*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.pooling_2d.padding_left, padding_left);
138*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.pooling_2d.pooling_height, pooling_height);
139*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.pooling_2d.pooling_width, pooling_width);
140*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 2);
141*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input_value_id);
142*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input_index_id);
143*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
144*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
145*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
146*4bdc9457SAndroid Build Coastguard Worker }
147*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Unpooling2DTestX32,matches_operator_api)148*4bdc9457SAndroid Build Coastguard Worker TEST_F(Unpooling2DTestX32, matches_operator_api)
149*4bdc9457SAndroid Build Coastguard Worker {
150*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
151*4bdc9457SAndroid Build Coastguard Worker 
152*4bdc9457SAndroid Build Coastguard Worker   std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); });
153*4bdc9457SAndroid Build Coastguard Worker   std::generate(input_index.begin(), input_index.end(), [&]() { return index_dist(rng); });
154*4bdc9457SAndroid Build Coastguard Worker   std::generate(operator_output.begin(), operator_output.end(), [&]() { return u32dist(rng); });
155*4bdc9457SAndroid Build Coastguard Worker   std::generate(subgraph_output.begin(), subgraph_output.end(), [&]() { return u32dist(rng); });
156*4bdc9457SAndroid Build Coastguard Worker 
157*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
158*4bdc9457SAndroid Build Coastguard Worker 
159*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
160*4bdc9457SAndroid Build Coastguard Worker   const xnn_status status = xnn_create_unpooling2d_nhwc_x32(
161*4bdc9457SAndroid Build Coastguard Worker     padding_top, padding_right, padding_bottom, padding_left, pooling_height, pooling_width, channels, channels,
162*4bdc9457SAndroid Build Coastguard Worker     channels, /*flags=*/0, &op);
163*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
164*4bdc9457SAndroid Build Coastguard Worker 
165*4bdc9457SAndroid Build Coastguard Worker   if (status == xnn_status_unsupported_hardware) {
166*4bdc9457SAndroid Build Coastguard Worker     GTEST_SKIP();
167*4bdc9457SAndroid Build Coastguard Worker   }
168*4bdc9457SAndroid Build Coastguard Worker 
169*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, status);
170*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, op);
171*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
172*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
173*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_unpooling2d_nhwc_x32(
174*4bdc9457SAndroid Build Coastguard Worker       op, batch_size, input_height, input_width, input.data(), input_index.data(), operator_output.data(),
175*4bdc9457SAndroid Build Coastguard Worker       /*threadpool=*/nullptr));
176*4bdc9457SAndroid Build Coastguard Worker 
177*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
178*4bdc9457SAndroid Build Coastguard Worker 
179*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
180*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
181*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(2, /*flags=*/0, &subgraph));
182*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
183*4bdc9457SAndroid Build Coastguard Worker 
184*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_value_id = XNN_INVALID_NODE_ID;
185*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
186*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
187*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input_value_dims.size(), input_value_dims.data(), nullptr,
188*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_value_id));
189*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_value_id, XNN_INVALID_NODE_ID);
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_index_id = XNN_INVALID_NODE_ID;
192*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
193*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
194*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input_index_dims.size(), input_index_dims.data(),
195*4bdc9457SAndroid Build Coastguard Worker                           input_index.data(), XNN_INVALID_VALUE_ID, /*flags=*/0, &input_index_id));
196*4bdc9457SAndroid Build Coastguard Worker 
197*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
198*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
199*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
200*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
201*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
202*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
203*4bdc9457SAndroid Build Coastguard Worker 
204*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
205*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_unpooling_2d(
206*4bdc9457SAndroid Build Coastguard Worker                           subgraph, padding_top, padding_right, padding_bottom, padding_left, pooling_height,
207*4bdc9457SAndroid Build Coastguard Worker                           pooling_width, input_value_id, input_index_id, output_id,
208*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/0));
209*4bdc9457SAndroid Build Coastguard Worker 
210*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
211*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
212*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
213*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
214*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 2> external = {
215*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input_value_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
216*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
217*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
218*4bdc9457SAndroid Build Coastguard Worker 
219*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
220*4bdc9457SAndroid Build Coastguard Worker }
221