xref: /aosp_15_r20/external/XNNPACK/test/static-constant-pad.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>   // For std::generate, std::shuffle.
7*4bdc9457SAndroid Build Coastguard Worker #include <array>       // For std::array.
8*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>     // For size_t.
9*4bdc9457SAndroid Build Coastguard Worker #include <functional>  // For std::multiplies.
10*4bdc9457SAndroid Build Coastguard Worker #include <memory>      // For std::unique_ptr.
11*4bdc9457SAndroid Build Coastguard Worker #include <random>      // For std::random_device, std::mt19937, std::uniform_real_distribution.
12*4bdc9457SAndroid Build Coastguard Worker #include <vector>      // For std::vector.
13*4bdc9457SAndroid Build Coastguard Worker 
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/node-type.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
18*4bdc9457SAndroid Build Coastguard Worker 
19*4bdc9457SAndroid Build Coastguard Worker #include "subgraph-unary-tester.h"
20*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
22*4bdc9457SAndroid Build Coastguard Worker 
// Fixture aliases: each instantiates the shared UnaryTest harness (from
// subgraph-unary-tester.h) for one element type used by the tests below.
using StaticConstantPadTestInt8 = UnaryTest<int8_t>;    // QS8 tensors.
using StaticConstantPadTestUint8 = UnaryTest<uint8_t>;  // QU8 tensors.
using StaticConstantPadTestF32 = UnaryTest<float>;      // FP32 tensors.
26*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(StaticConstantPadTestInt8,define)27*4bdc9457SAndroid Build Coastguard Worker TEST_F(StaticConstantPadTestInt8, define)
28*4bdc9457SAndroid Build Coastguard Worker {
29*4bdc9457SAndroid Build Coastguard Worker   const int32_t zero_point = i8dist(rng);
30*4bdc9457SAndroid Build Coastguard Worker   const float scale = scale_dist(rng);
31*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings;
32*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings;
33*4bdc9457SAndroid Build Coastguard Worker   std::fill(pre_paddings.begin(), pre_paddings.begin() + dims.size(), dim_dist(rng));
34*4bdc9457SAndroid Build Coastguard Worker   std::fill(post_paddings.begin(), post_paddings.begin() + dims.size(), dim_dist(rng));
35*4bdc9457SAndroid Build Coastguard Worker   float padding_value = f32dist(rng);
36*4bdc9457SAndroid Build Coastguard Worker   uint32_t quantized_padding_value = xnn_qs8_quantize(padding_value, scale, zero_point);
37*4bdc9457SAndroid Build Coastguard Worker 
38*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
39*4bdc9457SAndroid Build Coastguard Worker 
40*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
41*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
42*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
43*4bdc9457SAndroid Build Coastguard Worker 
44*4bdc9457SAndroid Build Coastguard Worker   input_id = XNN_INVALID_NODE_ID;
45*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
46*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
47*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, zero_point, scale, dims.size(), dims.data(), nullptr, 0,
48*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
49*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
50*4bdc9457SAndroid Build Coastguard Worker 
51*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
52*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
53*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
54*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, zero_point, scale, dims.size(), dims.data(), nullptr, 1,
55*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
56*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
57*4bdc9457SAndroid Build Coastguard Worker 
58*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
59*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
60*4bdc9457SAndroid Build Coastguard Worker     xnn_define_static_constant_pad(
61*4bdc9457SAndroid Build Coastguard Worker       subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0));
62*4bdc9457SAndroid Build Coastguard Worker 
63*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
64*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
65*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_static_constant_pad);
66*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
67*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < dims.size(); i++) {
68*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(node->params.static_pad.pre_paddings[i], pre_paddings[i]);
69*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(node->params.static_pad.post_paddings[i], post_paddings[i]);
70*4bdc9457SAndroid Build Coastguard Worker   }
71*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.static_pad.padding_value, quantized_padding_value);
72*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 1);
73*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input_id);
74*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
75*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
76*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
77*4bdc9457SAndroid Build Coastguard Worker }
78*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(StaticConstantPadTestUint8,define)79*4bdc9457SAndroid Build Coastguard Worker TEST_F(StaticConstantPadTestUint8, define)
80*4bdc9457SAndroid Build Coastguard Worker {
81*4bdc9457SAndroid Build Coastguard Worker   const int32_t zero_point = u8dist(rng);
82*4bdc9457SAndroid Build Coastguard Worker   const float scale = scale_dist(rng);
83*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings;
84*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings;
85*4bdc9457SAndroid Build Coastguard Worker   std::fill(pre_paddings.begin(), pre_paddings.begin() + dims.size(), dim_dist(rng));
86*4bdc9457SAndroid Build Coastguard Worker   std::fill(post_paddings.begin(), post_paddings.begin() + dims.size(), dim_dist(rng));
87*4bdc9457SAndroid Build Coastguard Worker   float padding_value = f32dist(rng);
88*4bdc9457SAndroid Build Coastguard Worker   uint32_t quantized_padding_value = xnn_qu8_quantize(padding_value, scale, zero_point);
89*4bdc9457SAndroid Build Coastguard Worker 
90*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
91*4bdc9457SAndroid Build Coastguard Worker 
92*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
93*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
94*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
95*4bdc9457SAndroid Build Coastguard Worker 
96*4bdc9457SAndroid Build Coastguard Worker   input_id = XNN_INVALID_NODE_ID;
97*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
98*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
99*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, zero_point, scale, dims.size(), dims.data(), nullptr, 0,
100*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
101*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
102*4bdc9457SAndroid Build Coastguard Worker 
103*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
104*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
105*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
106*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, zero_point, scale, dims.size(), dims.data(), nullptr, 1,
107*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
108*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
109*4bdc9457SAndroid Build Coastguard Worker 
110*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
111*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
112*4bdc9457SAndroid Build Coastguard Worker     xnn_define_static_constant_pad(
113*4bdc9457SAndroid Build Coastguard Worker       subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0));
114*4bdc9457SAndroid Build Coastguard Worker 
115*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
116*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
117*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_static_constant_pad);
118*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
119*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < dims.size(); i++) {
120*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(node->params.static_pad.pre_paddings[i], pre_paddings[i]);
121*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(node->params.static_pad.post_paddings[i], post_paddings[i]);
122*4bdc9457SAndroid Build Coastguard Worker   }
123*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.static_pad.padding_value, quantized_padding_value);
124*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 1);
125*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input_id);
126*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
127*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
128*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
129*4bdc9457SAndroid Build Coastguard Worker }
130*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(StaticConstantPadTestF32,define)131*4bdc9457SAndroid Build Coastguard Worker TEST_F(StaticConstantPadTestF32, define)
132*4bdc9457SAndroid Build Coastguard Worker {
133*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings;
134*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings;
135*4bdc9457SAndroid Build Coastguard Worker   std::fill(pre_paddings.begin(), pre_paddings.begin() + dims.size(), dim_dist(rng));
136*4bdc9457SAndroid Build Coastguard Worker   std::fill(post_paddings.begin(), post_paddings.begin() + dims.size(), dim_dist(rng));
137*4bdc9457SAndroid Build Coastguard Worker   float padding_value = f32dist(rng);
138*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_value_as_bits = float_as_uint32(padding_value);
139*4bdc9457SAndroid Build Coastguard Worker 
140*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
141*4bdc9457SAndroid Build Coastguard Worker 
142*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
143*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
144*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
145*4bdc9457SAndroid Build Coastguard Worker 
146*4bdc9457SAndroid Build Coastguard Worker   input_id = XNN_INVALID_NODE_ID;
147*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
148*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
149*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, 0,
150*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
151*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
152*4bdc9457SAndroid Build Coastguard Worker 
153*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
154*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
155*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
156*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, 1,
157*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
158*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
159*4bdc9457SAndroid Build Coastguard Worker 
160*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
161*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
162*4bdc9457SAndroid Build Coastguard Worker     xnn_define_static_constant_pad(
163*4bdc9457SAndroid Build Coastguard Worker       subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0));
164*4bdc9457SAndroid Build Coastguard Worker 
165*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
166*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
167*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_static_constant_pad);
168*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
169*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < dims.size(); i++) {
170*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(node->params.static_pad.pre_paddings[i], pre_paddings[i]);
171*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(node->params.static_pad.post_paddings[i], post_paddings[i]);
172*4bdc9457SAndroid Build Coastguard Worker   }
173*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.static_pad.padding_value, padding_value_as_bits);
174*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 1);
175*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input_id);
176*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
177*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
178*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
179*4bdc9457SAndroid Build Coastguard Worker }
180*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(StaticConstantPadTestInt8,matches_operator_api)181*4bdc9457SAndroid Build Coastguard Worker TEST_F(StaticConstantPadTestInt8, matches_operator_api)
182*4bdc9457SAndroid Build Coastguard Worker {
183*4bdc9457SAndroid Build Coastguard Worker   const int32_t zero_point = i8dist(rng);
184*4bdc9457SAndroid Build Coastguard Worker   const float scale = scale_dist(rng);
185*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings;
186*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings;
187*4bdc9457SAndroid Build Coastguard Worker   std::fill(pre_paddings.begin(), pre_paddings.begin() + dims.size(), dim_dist(rng));
188*4bdc9457SAndroid Build Coastguard Worker   std::fill(post_paddings.begin(), post_paddings.begin() + dims.size(), dim_dist(rng));
189*4bdc9457SAndroid Build Coastguard Worker   float padding_value = f32dist(rng);
190*4bdc9457SAndroid Build Coastguard Worker   uint32_t quantized_padding_value = xnn_qs8_quantize(padding_value, scale, zero_point);
191*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> output_dims = dims;
192*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < dims.size(); i++) {
193*4bdc9457SAndroid Build Coastguard Worker     output_dims[i] = pre_paddings[i] + output_dims[i] + post_paddings[i];
194*4bdc9457SAndroid Build Coastguard Worker   }
195*4bdc9457SAndroid Build Coastguard Worker   // Output sizes
196*4bdc9457SAndroid Build Coastguard Worker   operator_output = std::vector<int8_t>(NumElements(output_dims));
197*4bdc9457SAndroid Build Coastguard Worker   subgraph_output = std::vector<int8_t>(operator_output.size());
198*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
199*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
200*4bdc9457SAndroid Build Coastguard Worker 
201*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
202*4bdc9457SAndroid Build Coastguard Worker 
203*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
204*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
205*4bdc9457SAndroid Build Coastguard Worker   const xnn_status status = xnn_create_constant_pad_nd_x8(&quantized_padding_value, /*flags=*/0, &op);
206*4bdc9457SAndroid Build Coastguard Worker   if (status == xnn_status_unsupported_hardware) {
207*4bdc9457SAndroid Build Coastguard Worker     GTEST_SKIP();
208*4bdc9457SAndroid Build Coastguard Worker   }
209*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, status);
210*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, op);
211*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
212*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
213*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_constant_pad_nd_x8(
214*4bdc9457SAndroid Build Coastguard Worker                           op, dims.size(), dims.data(), pre_paddings.data(), post_paddings.data(), input.data(),
215*4bdc9457SAndroid Build Coastguard Worker                           operator_output.data(), /*threadpool=*/nullptr));
216*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
217*4bdc9457SAndroid Build Coastguard Worker 
218*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
219*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
220*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
221*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
222*4bdc9457SAndroid Build Coastguard Worker 
223*4bdc9457SAndroid Build Coastguard Worker   input_id = XNN_INVALID_NODE_ID;
224*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
225*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
226*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, zero_point, scale, dims.size(), dims.data(), nullptr, 0,
227*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
228*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
229*4bdc9457SAndroid Build Coastguard Worker 
230*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
231*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
232*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
233*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
234*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 1,
235*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
236*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
237*4bdc9457SAndroid Build Coastguard Worker 
238*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
239*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
240*4bdc9457SAndroid Build Coastguard Worker     xnn_define_static_constant_pad(
241*4bdc9457SAndroid Build Coastguard Worker       subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0));
242*4bdc9457SAndroid Build Coastguard Worker 
243*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
244*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
245*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
246*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
247*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 2> external = {
248*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
249*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
250*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
251*4bdc9457SAndroid Build Coastguard Worker 
252*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
253*4bdc9457SAndroid Build Coastguard Worker }
254*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(StaticConstantPadTestUint8,matches_operator_api)255*4bdc9457SAndroid Build Coastguard Worker TEST_F(StaticConstantPadTestUint8, matches_operator_api)
256*4bdc9457SAndroid Build Coastguard Worker {
257*4bdc9457SAndroid Build Coastguard Worker   const int32_t zero_point = u8dist(rng);
258*4bdc9457SAndroid Build Coastguard Worker   const float scale = scale_dist(rng);
259*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings;
260*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings;
261*4bdc9457SAndroid Build Coastguard Worker   std::fill(pre_paddings.begin(), pre_paddings.begin() + dims.size(), dim_dist(rng));
262*4bdc9457SAndroid Build Coastguard Worker   std::fill(post_paddings.begin(), post_paddings.begin() + dims.size(), dim_dist(rng));
263*4bdc9457SAndroid Build Coastguard Worker   float padding_value = f32dist(rng);
264*4bdc9457SAndroid Build Coastguard Worker   uint32_t quantized_padding_value = xnn_qu8_quantize(padding_value, scale, zero_point);
265*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> output_dims = dims;
266*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < dims.size(); i++) {
267*4bdc9457SAndroid Build Coastguard Worker     output_dims[i] = pre_paddings[i] + output_dims[i] + post_paddings[i];
268*4bdc9457SAndroid Build Coastguard Worker   }
269*4bdc9457SAndroid Build Coastguard Worker   // Output sizes
270*4bdc9457SAndroid Build Coastguard Worker   operator_output = std::vector<uint8_t>(NumElements(output_dims));
271*4bdc9457SAndroid Build Coastguard Worker   subgraph_output = std::vector<uint8_t>(operator_output.size());
272*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
273*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
274*4bdc9457SAndroid Build Coastguard Worker 
275*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
276*4bdc9457SAndroid Build Coastguard Worker 
277*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
278*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
279*4bdc9457SAndroid Build Coastguard Worker   const xnn_status status = xnn_create_constant_pad_nd_x8(&quantized_padding_value, /*flags=*/0, &op);
280*4bdc9457SAndroid Build Coastguard Worker   if (status == xnn_status_unsupported_hardware) {
281*4bdc9457SAndroid Build Coastguard Worker     GTEST_SKIP();
282*4bdc9457SAndroid Build Coastguard Worker   }
283*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, status);
284*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, op);
285*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
286*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
287*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_constant_pad_nd_x8(
288*4bdc9457SAndroid Build Coastguard Worker                           op, dims.size(), dims.data(), pre_paddings.data(), post_paddings.data(), input.data(),
289*4bdc9457SAndroid Build Coastguard Worker                           operator_output.data(), /*threadpool=*/nullptr));
290*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
291*4bdc9457SAndroid Build Coastguard Worker 
292*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
293*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
294*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
295*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
296*4bdc9457SAndroid Build Coastguard Worker 
297*4bdc9457SAndroid Build Coastguard Worker   input_id = XNN_INVALID_NODE_ID;
298*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
299*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
300*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, zero_point, scale, dims.size(), dims.data(), nullptr, 0,
301*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
302*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
303*4bdc9457SAndroid Build Coastguard Worker 
304*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
305*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
306*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
307*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
308*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 1,
309*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
310*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
311*4bdc9457SAndroid Build Coastguard Worker 
312*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
313*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
314*4bdc9457SAndroid Build Coastguard Worker     xnn_define_static_constant_pad(
315*4bdc9457SAndroid Build Coastguard Worker       subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0));
316*4bdc9457SAndroid Build Coastguard Worker 
317*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
318*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
319*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
320*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
321*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 2> external = {
322*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
323*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
324*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
325*4bdc9457SAndroid Build Coastguard Worker 
326*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
327*4bdc9457SAndroid Build Coastguard Worker }
328*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(StaticConstantPadTestF32,matches_operator_api)329*4bdc9457SAndroid Build Coastguard Worker TEST_F(StaticConstantPadTestF32, matches_operator_api)
330*4bdc9457SAndroid Build Coastguard Worker {
331*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings;
332*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings;
333*4bdc9457SAndroid Build Coastguard Worker   std::fill(pre_paddings.begin(), pre_paddings.begin() + dims.size(), dim_dist(rng));
334*4bdc9457SAndroid Build Coastguard Worker   std::fill(post_paddings.begin(), post_paddings.begin() + dims.size(), dim_dist(rng));
335*4bdc9457SAndroid Build Coastguard Worker   float padding_value = f32dist(rng);
336*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_value_as_u32 = float_as_uint32(padding_value);
337*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> output_dims = dims;
338*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < dims.size(); i++) {
339*4bdc9457SAndroid Build Coastguard Worker     output_dims[i] = pre_paddings[i] + output_dims[i] + post_paddings[i];
340*4bdc9457SAndroid Build Coastguard Worker   }
341*4bdc9457SAndroid Build Coastguard Worker   // Output sizes
342*4bdc9457SAndroid Build Coastguard Worker   operator_output = std::vector<float>(NumElements(output_dims));
343*4bdc9457SAndroid Build Coastguard Worker   subgraph_output = std::vector<float>(operator_output.size());
344*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), UINT32_C(0xDEADBEEF));
345*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT32_C(0xDEADBEEF));
346*4bdc9457SAndroid Build Coastguard Worker 
347*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
348*4bdc9457SAndroid Build Coastguard Worker 
349*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
350*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
351*4bdc9457SAndroid Build Coastguard Worker   const xnn_status status = xnn_create_constant_pad_nd_x32(&padding_value_as_u32, /*flags=*/0, &op);
352*4bdc9457SAndroid Build Coastguard Worker   if (status == xnn_status_unsupported_hardware) {
353*4bdc9457SAndroid Build Coastguard Worker     GTEST_SKIP();
354*4bdc9457SAndroid Build Coastguard Worker   }
355*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, status);
356*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, op);
357*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
358*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
359*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_constant_pad_nd_x32(
360*4bdc9457SAndroid Build Coastguard Worker                           op, dims.size(), dims.data(), pre_paddings.data(), post_paddings.data(), input.data(),
361*4bdc9457SAndroid Build Coastguard Worker                           operator_output.data(), /*threadpool=*/nullptr));
362*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));
363*4bdc9457SAndroid Build Coastguard Worker 
364*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
365*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
366*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
367*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
368*4bdc9457SAndroid Build Coastguard Worker 
369*4bdc9457SAndroid Build Coastguard Worker   input_id = XNN_INVALID_NODE_ID;
370*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
371*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
372*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, 0,
373*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
374*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
375*4bdc9457SAndroid Build Coastguard Worker 
376*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
377*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
378*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
379*4bdc9457SAndroid Build Coastguard Worker     xnn_define_tensor_value(
380*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 1,
381*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
382*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
383*4bdc9457SAndroid Build Coastguard Worker 
384*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
385*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
386*4bdc9457SAndroid Build Coastguard Worker     xnn_define_static_constant_pad(
387*4bdc9457SAndroid Build Coastguard Worker       subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0));
388*4bdc9457SAndroid Build Coastguard Worker 
389*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
390*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
391*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
392*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
393*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 2> external = {
394*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
395*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
396*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
397*4bdc9457SAndroid Build Coastguard Worker 
398*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
399*4bdc9457SAndroid Build Coastguard Worker }
400