// Source: /aosp_15_r20/external/XNNPACK/test/subtract2.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
10*4bdc9457SAndroid Build Coastguard Worker #include <array>
11*4bdc9457SAndroid Build Coastguard Worker #include <memory>
12*4bdc9457SAndroid Build Coastguard Worker #include <vector>
13*4bdc9457SAndroid Build Coastguard Worker 
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/node-type.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
18*4bdc9457SAndroid Build Coastguard Worker 
19*4bdc9457SAndroid Build Coastguard Worker #include "subgraph-binary-tester.h"
20*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker using Subtract2TestQS8 = BinaryTest<int8_t>;
23*4bdc9457SAndroid Build Coastguard Worker using Subtract2TestQU8 = BinaryTest<uint8_t>;
24*4bdc9457SAndroid Build Coastguard Worker using Subtract2TestF32 = BinaryTest<float>;
25*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Subtract2TestQS8,define)26*4bdc9457SAndroid Build Coastguard Worker TEST_F(Subtract2TestQS8, define)
27*4bdc9457SAndroid Build Coastguard Worker {
28*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = i8dist(rng);
29*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
30*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = i8dist(rng);
31*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
32*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = i8dist(rng);
33*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
34*4bdc9457SAndroid Build Coastguard Worker 
35*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
36*4bdc9457SAndroid Build Coastguard Worker 
37*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
38*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
39*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
40*4bdc9457SAndroid Build Coastguard Worker 
41*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
42*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
43*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
44*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
45*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
46*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/0, &input1_id));
47*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
48*4bdc9457SAndroid Build Coastguard Worker 
49*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
50*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
51*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
52*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
53*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
54*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/0, &input2_id));
55*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
56*4bdc9457SAndroid Build Coastguard Worker 
57*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
58*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
59*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
60*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(),
61*4bdc9457SAndroid Build Coastguard Worker                           output_dims.data(), nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &output_id));
62*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
63*4bdc9457SAndroid Build Coastguard Worker 
64*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
65*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
66*4bdc9457SAndroid Build Coastguard Worker     xnn_define_subtract(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
67*4bdc9457SAndroid Build Coastguard Worker 
68*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
69*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
70*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_subtract);
71*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
72*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_min, output_min);
73*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_max, output_max);
74*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 2);
75*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
76*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
77*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
78*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
79*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
80*4bdc9457SAndroid Build Coastguard Worker }
81*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Subtract2TestQU8,define)82*4bdc9457SAndroid Build Coastguard Worker TEST_F(Subtract2TestQU8, define)
83*4bdc9457SAndroid Build Coastguard Worker {
84*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = u8dist(rng);
85*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
86*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = u8dist(rng);
87*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
88*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = u8dist(rng);
89*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
90*4bdc9457SAndroid Build Coastguard Worker 
91*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
92*4bdc9457SAndroid Build Coastguard Worker 
93*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
94*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
95*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
96*4bdc9457SAndroid Build Coastguard Worker 
97*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
98*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
99*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
100*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
101*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
102*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/0, &input1_id));
103*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
104*4bdc9457SAndroid Build Coastguard Worker 
105*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
106*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
107*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
108*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
109*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
110*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/0, &input2_id));
111*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
112*4bdc9457SAndroid Build Coastguard Worker 
113*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
114*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
115*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
116*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, output_zero_point, output_scale, output_dims.size(),
117*4bdc9457SAndroid Build Coastguard Worker                           output_dims.data(), nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &output_id));
118*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
119*4bdc9457SAndroid Build Coastguard Worker 
120*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
121*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
122*4bdc9457SAndroid Build Coastguard Worker     xnn_define_subtract(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
125*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
126*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_subtract);
127*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
128*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_min, output_min);
129*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_max, output_max);
130*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 2);
131*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
132*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
133*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
134*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
135*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
136*4bdc9457SAndroid Build Coastguard Worker }
137*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Subtract2TestF32,define)138*4bdc9457SAndroid Build Coastguard Worker TEST_F(Subtract2TestF32, define)
139*4bdc9457SAndroid Build Coastguard Worker {
140*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
141*4bdc9457SAndroid Build Coastguard Worker 
142*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
143*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
144*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
145*4bdc9457SAndroid Build Coastguard Worker 
146*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
147*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
148*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
149*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr,
150*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input1_id));
151*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
152*4bdc9457SAndroid Build Coastguard Worker 
153*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
154*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
155*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
156*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr,
157*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input2_id));
158*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
159*4bdc9457SAndroid Build Coastguard Worker 
160*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
161*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
162*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
163*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
164*4bdc9457SAndroid Build Coastguard Worker                           XNN_INVALID_VALUE_ID, /*flags=*/0, &output_id));
165*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
166*4bdc9457SAndroid Build Coastguard Worker 
167*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
168*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
169*4bdc9457SAndroid Build Coastguard Worker     xnn_define_subtract(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
170*4bdc9457SAndroid Build Coastguard Worker 
171*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
172*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
173*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_subtract);
174*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
175*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_min, output_min);
176*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_max, output_max);
177*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 2);
178*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
179*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
180*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
181*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
182*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
183*4bdc9457SAndroid Build Coastguard Worker }
184*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Subtract2TestQS8,matches_operator_api)185*4bdc9457SAndroid Build Coastguard Worker TEST_F(Subtract2TestQS8, matches_operator_api)
186*4bdc9457SAndroid Build Coastguard Worker {
187*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = i8dist(rng);
188*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
189*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = i8dist(rng);
190*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
191*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = i8dist(rng);
192*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
193*4bdc9457SAndroid Build Coastguard Worker   const int8_t quantized_output_min = xnn_qs8_quantize(output_min, output_scale, output_zero_point);
194*4bdc9457SAndroid Build Coastguard Worker   const int8_t quantized_output_max = xnn_qs8_quantize(output_max, output_scale, output_zero_point);
195*4bdc9457SAndroid Build Coastguard Worker 
196*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
197*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
198*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
199*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
200*4bdc9457SAndroid Build Coastguard Worker 
201*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
202*4bdc9457SAndroid Build Coastguard Worker 
203*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
204*4bdc9457SAndroid Build Coastguard Worker 
205*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
206*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
207*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_create_subtract_nd_qs8(
208*4bdc9457SAndroid Build Coastguard Worker                           input1_zero_point, input1_scale, input2_zero_point, input2_scale, output_zero_point,
209*4bdc9457SAndroid Build Coastguard Worker                           output_scale, quantized_output_min, quantized_output_max, /*flags=*/0, &op));
210*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
211*4bdc9457SAndroid Build Coastguard Worker 
212*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
213*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_subtract_nd_qs8(
214*4bdc9457SAndroid Build Coastguard Worker                           op, input1_dims.size(), input1_dims.data(), input2_dims.size(), input2_dims.data(),
215*4bdc9457SAndroid Build Coastguard Worker                           input1.data(), input2.data(), operator_output.data(), nullptr /* thread pool */));
216*4bdc9457SAndroid Build Coastguard Worker 
217*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, nullptr /* thread pool */));
218*4bdc9457SAndroid Build Coastguard Worker 
219*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
220*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
221*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
222*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
223*4bdc9457SAndroid Build Coastguard Worker 
224*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
225*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
226*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
227*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
228*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
229*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
230*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
231*4bdc9457SAndroid Build Coastguard Worker 
232*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
233*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
234*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
235*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
236*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
237*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
238*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
239*4bdc9457SAndroid Build Coastguard Worker 
240*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
241*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
242*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
243*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(),
244*4bdc9457SAndroid Build Coastguard Worker                           output_dims.data(), nullptr, /*external_id=*/2,
245*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
246*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
247*4bdc9457SAndroid Build Coastguard Worker 
248*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
249*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
250*4bdc9457SAndroid Build Coastguard Worker     xnn_define_subtract(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
251*4bdc9457SAndroid Build Coastguard Worker 
252*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
253*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
254*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
255*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
256*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 3> external = {
257*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
258*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
259*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
260*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
261*4bdc9457SAndroid Build Coastguard Worker 
262*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
263*4bdc9457SAndroid Build Coastguard Worker }
264*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Subtract2TestQU8,matches_operator_api)265*4bdc9457SAndroid Build Coastguard Worker TEST_F(Subtract2TestQU8, matches_operator_api)
266*4bdc9457SAndroid Build Coastguard Worker {
267*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = u8dist(rng);
268*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
269*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = u8dist(rng);
270*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
271*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = u8dist(rng);
272*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
273*4bdc9457SAndroid Build Coastguard Worker   const uint8_t quantized_output_min = xnn_qu8_quantize(output_min, output_scale, output_zero_point);
274*4bdc9457SAndroid Build Coastguard Worker   const uint8_t quantized_output_max = xnn_qu8_quantize(output_max, output_scale, output_zero_point);
275*4bdc9457SAndroid Build Coastguard Worker 
276*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
277*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
278*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
279*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
280*4bdc9457SAndroid Build Coastguard Worker 
281*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
282*4bdc9457SAndroid Build Coastguard Worker 
283*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
284*4bdc9457SAndroid Build Coastguard Worker 
285*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
286*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
287*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_create_subtract_nd_qu8(
288*4bdc9457SAndroid Build Coastguard Worker                           input1_zero_point, input1_scale, input2_zero_point, input2_scale, output_zero_point,
289*4bdc9457SAndroid Build Coastguard Worker                           output_scale, quantized_output_min, quantized_output_max, /*flags=*/0, &op));
290*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
291*4bdc9457SAndroid Build Coastguard Worker 
292*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
293*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_subtract_nd_qu8(
294*4bdc9457SAndroid Build Coastguard Worker                           op, input1_dims.size(), input1_dims.data(), input2_dims.size(), input2_dims.data(),
295*4bdc9457SAndroid Build Coastguard Worker                           input1.data(), input2.data(), operator_output.data(), nullptr /* thread pool */));
296*4bdc9457SAndroid Build Coastguard Worker 
297*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, nullptr /* thread pool */));
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
300*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
301*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
302*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
303*4bdc9457SAndroid Build Coastguard Worker 
304*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
305*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
306*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
307*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
308*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
309*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
310*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
311*4bdc9457SAndroid Build Coastguard Worker 
312*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
313*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
314*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
315*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
316*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
317*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
318*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
319*4bdc9457SAndroid Build Coastguard Worker 
320*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
321*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
322*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
323*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, output_zero_point, output_scale, output_dims.size(),
324*4bdc9457SAndroid Build Coastguard Worker                           output_dims.data(), nullptr, /*external_id=*/2,
325*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
326*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
327*4bdc9457SAndroid Build Coastguard Worker 
328*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
329*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
330*4bdc9457SAndroid Build Coastguard Worker     xnn_define_subtract(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
331*4bdc9457SAndroid Build Coastguard Worker 
332*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
333*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
334*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
335*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
336*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 3> external = {
337*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
338*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
339*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
340*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
341*4bdc9457SAndroid Build Coastguard Worker 
342*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
343*4bdc9457SAndroid Build Coastguard Worker }
344*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Subtract2TestF32,matches_operator_api)345*4bdc9457SAndroid Build Coastguard Worker TEST_F(Subtract2TestF32, matches_operator_api)
346*4bdc9457SAndroid Build Coastguard Worker {
347*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
348*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
349*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), nanf(""));
350*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), nanf(""));
351*4bdc9457SAndroid Build Coastguard Worker 
352*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
353*4bdc9457SAndroid Build Coastguard Worker 
354*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
355*4bdc9457SAndroid Build Coastguard Worker 
356*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
357*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subtract_nd_f32(output_min, output_max, 0, &op));
358*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
359*4bdc9457SAndroid Build Coastguard Worker 
360*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
361*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_subtract_nd_f32(
362*4bdc9457SAndroid Build Coastguard Worker                           op, input1_dims.size(), input1_dims.data(), input2_dims.size(), input2_dims.data(),
363*4bdc9457SAndroid Build Coastguard Worker                           input1.data(), input2.data(), operator_output.data(), nullptr /* thread pool */));
364*4bdc9457SAndroid Build Coastguard Worker 
365*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, nullptr /* thread pool */));
366*4bdc9457SAndroid Build Coastguard Worker 
367*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
368*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
369*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
370*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
371*4bdc9457SAndroid Build Coastguard Worker 
372*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
373*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
374*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
375*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr,
376*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
377*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
378*4bdc9457SAndroid Build Coastguard Worker 
379*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
380*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
381*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
382*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr,
383*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
384*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
385*4bdc9457SAndroid Build Coastguard Worker 
386*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
387*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
388*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
389*4bdc9457SAndroid Build Coastguard Worker     xnn_define_tensor_value(
390*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, /*external_id=*/2,
391*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
392*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
393*4bdc9457SAndroid Build Coastguard Worker 
394*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
395*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
396*4bdc9457SAndroid Build Coastguard Worker     xnn_define_subtract(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
397*4bdc9457SAndroid Build Coastguard Worker 
398*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
399*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
400*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
401*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
402*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 3> external = {
403*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
404*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
405*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
406*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
407*4bdc9457SAndroid Build Coastguard Worker 
408*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
409*4bdc9457SAndroid Build Coastguard Worker }
410