// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
#include <algorithm>  // For std::generate, std::min.
#include <array>      // For std::array.
#include <cassert>    // For assert.
#include <cmath>      // For std::lrintf.
#include <cstddef>    // For size_t.
#include <cstdint>    // For uint32_t.
#include <functional> // For std::multiplies.
#include <limits>     // For std::numeric_limits.
#include <memory>     // For std::unique_ptr.
#include <numeric>    // For std::accumulate.
#include <random>     // For std::random_device, std::mt19937, std::uniform_real_distribution.
#include <vector>     // For std::vector.

#include <xnnpack.h>
#include <xnnpack/operator.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>
23*4bdc9457SAndroid Build Coastguard Worker
24*4bdc9457SAndroid Build Coastguard Worker template <class T, class BiasType = T> class FullyConnectedTestBase : public ::testing::Test {
25*4bdc9457SAndroid Build Coastguard Worker protected:
FullyConnectedTestBase()26*4bdc9457SAndroid Build Coastguard Worker FullyConnectedTestBase()
27*4bdc9457SAndroid Build Coastguard Worker {
28*4bdc9457SAndroid Build Coastguard Worker random_device = std::unique_ptr<std::random_device>(new std::random_device());
29*4bdc9457SAndroid Build Coastguard Worker rng = std::mt19937((*random_device)());
30*4bdc9457SAndroid Build Coastguard Worker input_size_dist = std::uniform_int_distribution<uint32_t>(10, 15);
31*4bdc9457SAndroid Build Coastguard Worker kernel_size_dist = std::uniform_int_distribution<uint32_t>(1, 5);
32*4bdc9457SAndroid Build Coastguard Worker stride_dist = std::uniform_int_distribution<uint32_t>(1, 2);
33*4bdc9457SAndroid Build Coastguard Worker f32dist = std::uniform_real_distribution<float>(0.1f, 1.0f);
34*4bdc9457SAndroid Build Coastguard Worker scale_dist = std::uniform_real_distribution<float>(1.0f, 5.0f);
35*4bdc9457SAndroid Build Coastguard Worker i32dist = std::uniform_int_distribution<int32_t>(-10000, 10000);
36*4bdc9457SAndroid Build Coastguard Worker auto shape_dist = std::uniform_int_distribution<size_t>(2, XNN_MAX_TENSOR_DIMS);
37*4bdc9457SAndroid Build Coastguard Worker dim_dist = std::uniform_int_distribution<size_t>(5, 15);
38*4bdc9457SAndroid Build Coastguard Worker i8dist =
39*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
40*4bdc9457SAndroid Build Coastguard Worker w8dist =
41*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<uint8_t>::max(), std::numeric_limits<uint8_t>::max());
42*4bdc9457SAndroid Build Coastguard Worker
43*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity();
44*4bdc9457SAndroid Build Coastguard Worker output_max = std::numeric_limits<float>::infinity();
45*4bdc9457SAndroid Build Coastguard Worker
46*4bdc9457SAndroid Build Coastguard Worker size_t num_input_dims = shape_dist(rng);
47*4bdc9457SAndroid Build Coastguard Worker input_dims = RandomShape(num_input_dims);
48*4bdc9457SAndroid Build Coastguard Worker assert(input_dims.size() >= 2);
49*4bdc9457SAndroid Build Coastguard Worker output_channels = dim_dist(rng);
50*4bdc9457SAndroid Build Coastguard Worker input_channels = input_dims.back();
51*4bdc9457SAndroid Build Coastguard Worker kernel_dims = {output_channels, input_channels};
52*4bdc9457SAndroid Build Coastguard Worker output_dims = input_dims;
53*4bdc9457SAndroid Build Coastguard Worker output_dims[output_dims.size() - 1] = output_channels;
54*4bdc9457SAndroid Build Coastguard Worker
55*4bdc9457SAndroid Build Coastguard Worker batch_size = NumElements(input_dims) / input_channels;
56*4bdc9457SAndroid Build Coastguard Worker
57*4bdc9457SAndroid Build Coastguard Worker input = std::vector<T>(XNN_EXTRA_BYTES / sizeof(T) + NumElements(input_dims));
58*4bdc9457SAndroid Build Coastguard Worker kernel = std::vector<T>(input_channels * output_channels);
59*4bdc9457SAndroid Build Coastguard Worker bias = std::vector<BiasType>(output_channels);
60*4bdc9457SAndroid Build Coastguard Worker operator_output = std::vector<T>(NumElements(output_dims));
61*4bdc9457SAndroid Build Coastguard Worker subgraph_output = std::vector<T>(operator_output.size());
62*4bdc9457SAndroid Build Coastguard Worker accumulators = std::vector<int32_t>(batch_size * output_channels);
63*4bdc9457SAndroid Build Coastguard Worker }
64*4bdc9457SAndroid Build Coastguard Worker
RandomShape(size_t num_dims)65*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> RandomShape(size_t num_dims)
66*4bdc9457SAndroid Build Coastguard Worker {
67*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> dims(num_dims);
68*4bdc9457SAndroid Build Coastguard Worker std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
69*4bdc9457SAndroid Build Coastguard Worker return dims;
70*4bdc9457SAndroid Build Coastguard Worker }
71*4bdc9457SAndroid Build Coastguard Worker
NumElements(std::vector<size_t> & dims)72*4bdc9457SAndroid Build Coastguard Worker size_t NumElements(std::vector<size_t>& dims)
73*4bdc9457SAndroid Build Coastguard Worker {
74*4bdc9457SAndroid Build Coastguard Worker return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
75*4bdc9457SAndroid Build Coastguard Worker }
76*4bdc9457SAndroid Build Coastguard Worker
77*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<std::random_device> random_device;
78*4bdc9457SAndroid Build Coastguard Worker std::mt19937 rng;
79*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t> input_size_dist;
80*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t> kernel_size_dist;
81*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t> stride_dist;
82*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist;
83*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
84*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> scale_dist;
85*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<size_t> dim_dist;
86*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist;
87*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist;
88*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> w8dist;
89*4bdc9457SAndroid Build Coastguard Worker
90*4bdc9457SAndroid Build Coastguard Worker uint32_t batch_size;
91*4bdc9457SAndroid Build Coastguard Worker size_t input_channels;
92*4bdc9457SAndroid Build Coastguard Worker size_t output_channels;
93*4bdc9457SAndroid Build Coastguard Worker
94*4bdc9457SAndroid Build Coastguard Worker float output_min;
95*4bdc9457SAndroid Build Coastguard Worker float output_max;
96*4bdc9457SAndroid Build Coastguard Worker
97*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> input_dims;
98*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> kernel_dims;
99*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> bias_dims;
100*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> output_dims;
101*4bdc9457SAndroid Build Coastguard Worker
102*4bdc9457SAndroid Build Coastguard Worker std::vector<T> input;
103*4bdc9457SAndroid Build Coastguard Worker std::vector<T> kernel;
104*4bdc9457SAndroid Build Coastguard Worker std::vector<BiasType> bias;
105*4bdc9457SAndroid Build Coastguard Worker std::vector<T> operator_output;
106*4bdc9457SAndroid Build Coastguard Worker std::vector<T> subgraph_output;
107*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators;
108*4bdc9457SAndroid Build Coastguard Worker };
109*4bdc9457SAndroid Build Coastguard Worker
110*4bdc9457SAndroid Build Coastguard Worker template <class T> class QuantizedFullyConnectedTestBase : public FullyConnectedTestBase<T, int32_t> {
111*4bdc9457SAndroid Build Coastguard Worker protected:
initialize_accumulators_from_bias()112*4bdc9457SAndroid Build Coastguard Worker void initialize_accumulators_from_bias()
113*4bdc9457SAndroid Build Coastguard Worker {
114*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < this->batch_size; i++) {
115*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < this->output_channels; oc++) {
116*4bdc9457SAndroid Build Coastguard Worker this->accumulators[i * this->output_channels + oc] = this->bias[oc];
117*4bdc9457SAndroid Build Coastguard Worker }
118*4bdc9457SAndroid Build Coastguard Worker }
119*4bdc9457SAndroid Build Coastguard Worker }
120*4bdc9457SAndroid Build Coastguard Worker };
121*4bdc9457SAndroid Build Coastguard Worker
122*4bdc9457SAndroid Build Coastguard Worker using FullyConnectedTestQS8 = QuantizedFullyConnectedTestBase<int8_t>;
123*4bdc9457SAndroid Build Coastguard Worker using FullyConnectedTestQU8 = QuantizedFullyConnectedTestBase<uint8_t>;
124*4bdc9457SAndroid Build Coastguard Worker using FullyConnectedTestF32 = FullyConnectedTestBase<float>;
125*4bdc9457SAndroid Build Coastguard Worker
// Defines a QS8 fully-connected node in a fresh subgraph and checks that the
// recorded node carries the expected type, compute type, activation bounds,
// inputs, outputs and flags.
TEST_F(FullyConnectedTestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  // Four values are defined below, using external IDs 0 through 3.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph));
  // Deletes the subgraph on every exit path, including failed assertions.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint8, 0, 1.0f, input_dims.size(), input_dims.data(), nullptr,
                          /*external_id=*/0, /*flags=*/0, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  uint32_t kernel_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint8, 0, 1.0f, kernel_dims.size(), kernel_dims.data(), kernel.data(),
                          /*external_id=*/1, /*flags=*/0, &kernel_id));
  // Checked like the input/output IDs so a silently unset ID cannot reach the
  // node definition.
  ASSERT_NE(kernel_id, XNN_INVALID_NODE_ID);

  uint32_t bias_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint32, 0, 1.0f, bias_dims.size(), bias_dims.data(), bias.data(),
                          /*external_id=*/2, /*flags=*/0, &bias_id));
  ASSERT_NE(bias_id, XNN_INVALID_NODE_ID);

  uint32_t output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint8, 0, 1.0f, output_dims.size(), output_dims.data(), nullptr,
                          /*external_id=*/3, /*flags=*/0, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0));

  // Exactly one node must have been recorded, describing the call above.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_fully_connected);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->activation.output_min, output_min);
  ASSERT_EQ(node->activation.output_max, output_max);
  ASSERT_EQ(node->num_inputs, 3);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->inputs[1], kernel_id);
  ASSERT_EQ(node->inputs[2], bias_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}
178*4bdc9457SAndroid Build Coastguard Worker
// Defines a QU8 fully-connected node in a fresh subgraph and checks that the
// recorded node carries the expected type, compute type, activation bounds,
// inputs, outputs and flags.
TEST_F(FullyConnectedTestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  // Four values are defined below, using external IDs 0 through 3.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph));
  // Deletes the subgraph on every exit path, including failed assertions.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_quint8, 0, 1.0f, input_dims.size(), input_dims.data(), nullptr,
                          /*external_id=*/0, /*flags=*/0, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  uint32_t kernel_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_quint8, 0, 1.0f, kernel_dims.size(), kernel_dims.data(), kernel.data(),
                          /*external_id=*/1, /*flags=*/0, &kernel_id));
  // Checked like the input/output IDs so a silently unset ID cannot reach the
  // node definition.
  ASSERT_NE(kernel_id, XNN_INVALID_NODE_ID);

  uint32_t bias_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint32, 0, 1.0f, bias_dims.size(), bias_dims.data(), bias.data(),
                          /*external_id=*/2, /*flags=*/0, &bias_id));
  ASSERT_NE(bias_id, XNN_INVALID_NODE_ID);

  uint32_t output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_quint8, 0, 1.0f, output_dims.size(), output_dims.data(), nullptr,
                          /*external_id=*/3, /*flags=*/0, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success, xnn_define_fully_connected(
                          subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id,
                          /*flags=*/0));

  // Exactly one node must have been recorded, describing the call above.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_fully_connected);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->activation.output_min, output_min);
  ASSERT_EQ(node->activation.output_max, output_max);
  ASSERT_EQ(node->num_inputs, 3);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->inputs[1], kernel_id);
  ASSERT_EQ(node->inputs[2], bias_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}
232*4bdc9457SAndroid Build Coastguard Worker
// Defines an F32 fully-connected node in a fresh subgraph and checks that the
// recorded node carries the expected type, compute type, activation bounds,
// inputs, outputs and flags.
TEST_F(FullyConnectedTestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  // Four values are defined below, using external IDs 0 through 3.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph));
  // Deletes the subgraph on every exit path, including failed assertions.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr,
                          /*external_id=*/0, /*flags=*/0, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  uint32_t kernel_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(), /*external_id=*/1,
      /*flags=*/0, &kernel_id));
  // Checked like the input/output IDs so a silently unset ID cannot reach the
  // node definition.
  ASSERT_NE(kernel_id, XNN_INVALID_NODE_ID);

  uint32_t bias_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(),
                          /*external_id=*/2, /*flags=*/0, &bias_id));
  ASSERT_NE(bias_id, XNN_INVALID_NODE_ID);

  uint32_t output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
                          /*external_id=*/3, /*flags=*/0, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0));

  // Exactly one node must have been recorded, describing the call above.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_fully_connected);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->activation.output_min, output_min);
  ASSERT_EQ(node->activation.output_max, output_max);
  ASSERT_EQ(node->num_inputs, 3);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->inputs[1], kernel_id);
  ASSERT_EQ(node->inputs[2], bias_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}
286*4bdc9457SAndroid Build Coastguard Worker
// Runs the same QS8 fully-connected computation through the operator API and
// through a one-node subgraph runtime, then requires the two output buffers
// to be element-wise identical.
TEST_F(FullyConnectedTestQS8, matches_operator_api)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op = nullptr;

  std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
  std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
  std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
  // Pre-fill both outputs with the same sentinel pattern.
  std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
  const int8_t input_zero_point = -1;
  const float input_scale = scale_dist(rng);
  const float kernel_scale = scale_dist(rng);

  // Compute reference results, without renormalization.
  initialize_accumulators_from_bias();
  for (size_t i = 0; i < batch_size; i++) {
    for (size_t oc = 0; oc < output_channels; oc++) {
      for (size_t ic = 0; ic < input_channels; ic++) {
        accumulators[i * output_channels + oc] +=
          (int32_t(input[i * input_channels + ic]) - int32_t(input_zero_point)) *
          int32_t(kernel[oc * input_channels + ic]);
      }
    }
  }

  // Compute renormalization parameters: the scale spreads the observed
  // accumulator range over 255 steps, and the zero point is derived from the
  // midpoint of that range and clamped to the int8_t limits.
  const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
  const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

  float output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
  // std::lrint is qualified because <cmath> only guarantees the std:: name.
  int8_t output_zero_point = int8_t(std::max(
    std::min(
      std::lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
      long(std::numeric_limits<int8_t>::max())),
    long(std::numeric_limits<int8_t>::min())));
  const int8_t quantized_output_min = xnn_qs8_quantize(output_min, output_scale, output_zero_point);
  const int8_t quantized_output_max = xnn_qs8_quantize(output_max, output_scale, output_zero_point);

  // Call operator API.
  const xnn_status status = xnn_create_fully_connected_nc_qs8(
    input_channels, output_channels, input_channels, output_channels, input_zero_point, input_scale, kernel_scale,
    kernel.data(), bias.data(), output_zero_point, output_scale, quantized_output_min, quantized_output_max,
    /*flags=*/0, nullptr, &op);
  // Deletes the operator on every exit path, including the skip below.
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);

  if (status == xnn_status_unsupported_hardware) {
    GTEST_SKIP();
  }

  ASSERT_EQ(xnn_status_success, status);
  ASSERT_NE(nullptr, op);
  ASSERT_EQ(
    xnn_status_success, xnn_setup_fully_connected_nc_qs8(
                          op, batch_size, input.data(), operator_output.data(),
                          /*threadpool=*/nullptr));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr));

  // Call subgraph API. Four values are defined, using external IDs 0 through 3.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint8, input_zero_point, input_scale, input_dims.size(),
                          input_dims.data(), nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  uint32_t kernel_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint8, 0, kernel_scale, kernel_dims.size(), kernel_dims.data(),
                          kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));
  // Checked like the input/output IDs so a silently unset ID cannot reach the
  // node definition.
  ASSERT_NE(kernel_id, XNN_INVALID_NODE_ID);

  uint32_t bias_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint32, 0, kernel_scale, bias_dims.size(), bias_dims.data(),
                          bias.data(), /*external_id=*/2, /*flags=*/0, &bias_id));
  ASSERT_NE(bias_id, XNN_INVALID_NODE_ID);

  uint32_t output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_quantized_tensor_value(
                          subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(),
                          output_dims.data(), nullptr, /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  // Bind the external input and output values to this test's buffers.
  std::array<xnn_external_value, 2> external = {
    xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match. The values are widened so a failure prints numbers
  // rather than characters, and the failing index is reported.
  for (size_t i = 0; i < operator_output.size(); i++) {
    ASSERT_EQ(int32_t(subgraph_output[i]), int32_t(operator_output[i])) << "at output index " << i;
  }
}
395*4bdc9457SAndroid Build Coastguard Worker
// Runs the same QU8 fully-connected computation twice -- once through the standalone operator API and once through
// a one-node subgraph/runtime -- and requires the two output buffers to agree element by element.
TEST_F(FullyConnectedTestQU8, matches_operator_api)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  // Random operands. The order of the RNG draws below is kept fixed.
  std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
  std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
  std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
  // Both output buffers start out filled with the same 0xA5 byte pattern.
  std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
  const uint8_t input_zp = u8dist(rng);
  const uint8_t kernel_zp = 0;
  const float input_q_scale = scale_dist(rng);
  const float kernel_q_scale = scale_dist(rng);

  // Reference accumulation (bias + sum of zero-point-corrected products), without renormalization.
  initialize_accumulators_from_bias();
  for (size_t n = 0; n < batch_size; ++n) {
    for (size_t out_ch = 0; out_ch < output_channels; ++out_ch) {
      auto& acc = accumulators[n * output_channels + out_ch];
      for (size_t in_ch = 0; in_ch < input_channels; ++in_ch) {
        const int32_t centered_input =
          static_cast<int32_t>(input[n * input_channels + in_ch]) - static_cast<int32_t>(input_zp);
        const int32_t centered_weight =
          static_cast<int32_t>(kernel[out_ch * input_channels + in_ch]) - static_cast<int32_t>(kernel_zp);
        acc += centered_input * centered_weight;
      }
    }
  }

  // Derive the output quantization parameters from the range spanned by the reference accumulators.
  const auto acc_extremes = std::minmax_element(accumulators.cbegin(), accumulators.cend());
  const int32_t acc_lo = *acc_extremes.first;
  const int32_t acc_hi = *acc_extremes.second;

  const double out_q_scale = static_cast<double>(static_cast<uint32_t>(acc_hi - acc_lo)) / 255.0;
  // Round the zero point, then clamp it into the uint8_t range before narrowing.
  const long rounded_zp = lrint(127.5 - 0.5 * static_cast<double>(acc_lo + acc_hi) / out_q_scale);
  const long zp_upper = static_cast<long>(std::numeric_limits<uint8_t>::max());
  const long zp_lower = static_cast<long>(std::numeric_limits<uint8_t>::min());
  const uint8_t out_zp = static_cast<uint8_t>(std::max(std::min(rounded_zp, zp_upper), zp_lower));
  const uint8_t q_out_min = xnn_qu8_quantize(output_min, out_q_scale, out_zp);
  const uint8_t q_out_max = xnn_qu8_quantize(output_max, out_q_scale, out_zp);

  // ---- Operator API ----
  xnn_operator_t fc_op = nullptr;
  const xnn_status create_status = xnn_create_fully_connected_nc_qu8(
    input_channels, output_channels, input_channels, output_channels, input_zp, input_q_scale, kernel_zp,
    kernel_q_scale, kernel.data(), bias.data(), out_zp, out_q_scale, q_out_min, q_out_max,
    /*flags=*/0, nullptr, &fc_op);
  // Hand the operator to a guard right away so it is deleted on every exit path, including the skip below.
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> op_guard(fc_op, xnn_delete_operator);

  // Skip the test when operator creation reports unsupported hardware.
  if (create_status == xnn_status_unsupported_hardware) {
    GTEST_SKIP();
  }

  ASSERT_EQ(xnn_status_success, create_status);
  ASSERT_NE(nullptr, fc_op);
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_fully_connected_nc_qu8(
      fc_op, batch_size, input.data(), operator_output.data(), /*threadpool=*/nullptr));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr));

  // ---- Subgraph API ----
  // External IDs 0..3 are assigned to the input, kernel, bias, and output values defined below.
  xnn_subgraph_t graph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &graph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> graph_guard(graph, xnn_delete_subgraph);

  uint32_t input_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_quint8, input_zp, input_q_scale, input_dims.size(), input_dims.data(), nullptr,
      /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_value));
  ASSERT_NE(input_value, XNN_INVALID_NODE_ID);

  // Static kernel tensor; its zero point is kernel_zp (0), matching the operator call above.
  uint32_t kernel_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_quint8, kernel_zp, kernel_q_scale, kernel_dims.size(), kernel_dims.data(), kernel.data(),
      /*external_id=*/1, /*flags=*/0, &kernel_value));

  // Static bias tensor, defined with zero point 0 and the kernel scale.
  uint32_t bias_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_qint32, 0, kernel_q_scale, bias_dims.size(), bias_dims.data(), bias.data(),
      /*external_id=*/2, /*flags=*/0, &bias_value));

  uint32_t output_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      graph, xnn_datatype_quint8, out_zp, out_q_scale, output_dims.size(), output_dims.data(), nullptr,
      /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_value));
  ASSERT_NE(output_value, XNN_INVALID_NODE_ID);
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_fully_connected(
      graph, output_min, output_max, input_value, kernel_value, bias_value, output_value, /*flags=*/0));

  xnn_runtime_t rt = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(graph, nullptr, nullptr, /*flags=*/0, &rt));
  ASSERT_NE(nullptr, rt);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> rt_guard(rt, xnn_delete_runtime);
  // Bind the external input and output values to the test buffers.
  std::array<xnn_external_value, 2> io_bindings = {
    xnn_external_value{input_value, input.data()}, xnn_external_value{output_value, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(rt, io_bindings.size(), io_bindings.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(rt));

  // The subgraph result must equal the operator result at every index.
  for (size_t idx = 0; idx < operator_output.size(); ++idx) {
    ASSERT_EQ(subgraph_output[idx], operator_output[idx]);
  }
}
505*4bdc9457SAndroid Build Coastguard Worker
// Runs the same F32 fully-connected computation through the standalone operator API and through a one-node
// subgraph/runtime, then requires the two output buffers to agree element by element.
TEST_F(FullyConnectedTestF32, matches_operator_api)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  // Random operands. The order of the RNG draws below is kept fixed.
  std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
  std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
  std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
  // Both output buffers start out filled with NaN.
  std::fill(operator_output.begin(), operator_output.end(), nanf(""));
  std::fill(subgraph_output.begin(), subgraph_output.end(), nanf(""));

  // ---- Operator API ----
  xnn_operator_t fc_op = nullptr;
  const xnn_status create_status = xnn_create_fully_connected_nc_f32(
    input_channels, output_channels, input_channels, output_channels, kernel.data(), bias.data(), output_min,
    output_max, /*flags=*/0, nullptr, &fc_op);
  // Hand the operator to a guard right away so it is deleted on every exit path, including the skip below.
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> op_guard(fc_op, xnn_delete_operator);

  // Skip the test when operator creation reports unsupported hardware.
  if (create_status == xnn_status_unsupported_hardware) {
    GTEST_SKIP();
  }

  ASSERT_EQ(xnn_status_success, create_status);
  ASSERT_NE(nullptr, fc_op);
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_fully_connected_nc_f32(
      fc_op, batch_size, input.data(), operator_output.data(), /*threadpool=*/nullptr));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr));

  // ---- Subgraph API ----
  // External IDs 0..3 are assigned to the input, kernel, bias, and output values defined below.
  xnn_subgraph_t graph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &graph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> graph_guard(graph, xnn_delete_subgraph);

  uint32_t input_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      graph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr,
      /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_value));
  ASSERT_NE(input_value, XNN_INVALID_NODE_ID);

  // Static kernel tensor backed by the fixture's kernel buffer.
  uint32_t kernel_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      graph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(),
      /*external_id=*/1, /*flags=*/0, &kernel_value));

  // Static bias tensor backed by the fixture's bias buffer.
  uint32_t bias_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      graph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(),
      /*external_id=*/2, /*flags=*/0, &bias_value));

  uint32_t output_value = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      graph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
      /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_value));
  ASSERT_NE(output_value, XNN_INVALID_NODE_ID);
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_fully_connected(
      graph, output_min, output_max, input_value, kernel_value, bias_value, output_value, /*flags=*/0));

  xnn_runtime_t rt = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(graph, nullptr, nullptr, /*flags=*/0, &rt));
  ASSERT_NE(nullptr, rt);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> rt_guard(rt, xnn_delete_runtime);
  // Bind the external input and output values to the test buffers.
  std::array<xnn_external_value, 2> io_bindings = {
    xnn_external_value{input_value, input.data()}, xnn_external_value{output_value, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(rt, io_bindings.size(), io_bindings.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(rt));

  // The subgraph result must equal the operator result at every index (exact float equality).
  for (size_t idx = 0; idx < operator_output.size(); ++idx) {
    ASSERT_EQ(subgraph_output[idx], operator_output[idx]);
  }
}
586