xref: /aosp_15_r20/external/XNNPACK/test/multiply2.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
7*4bdc9457SAndroid Build Coastguard Worker #include <array>
8*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
9*4bdc9457SAndroid Build Coastguard Worker #include <cstdint>
10*4bdc9457SAndroid Build Coastguard Worker #include <memory>
11*4bdc9457SAndroid Build Coastguard Worker #include <vector>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/node-type.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
17*4bdc9457SAndroid Build Coastguard Worker 
18*4bdc9457SAndroid Build Coastguard Worker #include "subgraph-binary-tester.h"
19*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
20*4bdc9457SAndroid Build Coastguard Worker 
21*4bdc9457SAndroid Build Coastguard Worker using Multiply2TestQS8 = BinaryTest<int8_t>;
22*4bdc9457SAndroid Build Coastguard Worker using Multiply2TestQU8 = BinaryTest<uint8_t>;
23*4bdc9457SAndroid Build Coastguard Worker using Multiply2TestF32 = BinaryTest<float>;
24*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Multiply2TestQS8,define)25*4bdc9457SAndroid Build Coastguard Worker TEST_F(Multiply2TestQS8, define)
26*4bdc9457SAndroid Build Coastguard Worker {
27*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = i8dist(rng);
28*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
29*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = i8dist(rng);
30*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
31*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = i8dist(rng);
32*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
33*4bdc9457SAndroid Build Coastguard Worker 
34*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
35*4bdc9457SAndroid Build Coastguard Worker 
36*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
37*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
38*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
39*4bdc9457SAndroid Build Coastguard Worker 
40*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
41*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
42*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
43*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
44*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input1_id));
45*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
46*4bdc9457SAndroid Build Coastguard Worker 
47*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
48*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
49*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
50*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
51*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input2_id));
52*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
53*4bdc9457SAndroid Build Coastguard Worker 
54*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
55*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
56*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
57*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(), output_dims.data(), nullptr,
58*4bdc9457SAndroid Build Coastguard Worker                           XNN_INVALID_VALUE_ID, /*flags=*/0, &output_id));
59*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
60*4bdc9457SAndroid Build Coastguard Worker 
61*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
62*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
63*4bdc9457SAndroid Build Coastguard Worker     xnn_define_multiply2(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
64*4bdc9457SAndroid Build Coastguard Worker 
65*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
66*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
67*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_multiply2);
68*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
69*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_min, output_min);
70*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_max, output_max);
71*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 2);
72*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
73*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
74*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
75*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
76*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
77*4bdc9457SAndroid Build Coastguard Worker }
78*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Multiply2TestQU8,define)79*4bdc9457SAndroid Build Coastguard Worker TEST_F(Multiply2TestQU8, define)
80*4bdc9457SAndroid Build Coastguard Worker {
81*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = u8dist(rng);
82*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
83*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = u8dist(rng);
84*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
85*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = u8dist(rng);
86*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
87*4bdc9457SAndroid Build Coastguard Worker 
88*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
89*4bdc9457SAndroid Build Coastguard Worker 
90*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
91*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
92*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
93*4bdc9457SAndroid Build Coastguard Worker 
94*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
95*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
96*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
97*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
98*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input1_id));
99*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
100*4bdc9457SAndroid Build Coastguard Worker 
101*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
102*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
103*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
104*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
105*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input2_id));
106*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
107*4bdc9457SAndroid Build Coastguard Worker 
108*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
109*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
110*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
111*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, output_zero_point, output_scale, output_dims.size(), output_dims.data(), nullptr,
112*4bdc9457SAndroid Build Coastguard Worker                           XNN_INVALID_VALUE_ID, /*flags=*/0, &output_id));
113*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
114*4bdc9457SAndroid Build Coastguard Worker 
115*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
116*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
117*4bdc9457SAndroid Build Coastguard Worker     xnn_define_multiply2(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
118*4bdc9457SAndroid Build Coastguard Worker 
119*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
120*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
121*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_multiply2);
122*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
123*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_min, output_min);
124*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_max, output_max);
125*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 2);
126*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
127*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
128*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
129*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
130*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
131*4bdc9457SAndroid Build Coastguard Worker }
132*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Multiply2TestF32,define)133*4bdc9457SAndroid Build Coastguard Worker TEST_F(Multiply2TestF32, define)
134*4bdc9457SAndroid Build Coastguard Worker {
135*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
136*4bdc9457SAndroid Build Coastguard Worker 
137*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
138*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
139*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
140*4bdc9457SAndroid Build Coastguard Worker 
141*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
142*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
143*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
144*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr,
145*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input1_id));
146*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
147*4bdc9457SAndroid Build Coastguard Worker 
148*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
149*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
150*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
151*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr,
152*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/0, &input2_id));
153*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
154*4bdc9457SAndroid Build Coastguard Worker 
155*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
156*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
157*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
158*4bdc9457SAndroid Build Coastguard Worker     xnn_define_tensor_value(
159*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &output_id));
160*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
161*4bdc9457SAndroid Build Coastguard Worker 
162*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
163*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
164*4bdc9457SAndroid Build Coastguard Worker     xnn_define_multiply2(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
165*4bdc9457SAndroid Build Coastguard Worker 
166*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
167*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
168*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_multiply2);
169*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
170*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_min, output_min);
171*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->activation.output_max, output_max);
172*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 2);
173*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
174*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
175*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
176*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
177*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
178*4bdc9457SAndroid Build Coastguard Worker }
179*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Multiply2TestQS8,matches_operator_api)180*4bdc9457SAndroid Build Coastguard Worker TEST_F(Multiply2TestQS8, matches_operator_api)
181*4bdc9457SAndroid Build Coastguard Worker {
182*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = i8dist(rng);
183*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
184*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = i8dist(rng);
185*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
186*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = i8dist(rng);
187*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
188*4bdc9457SAndroid Build Coastguard Worker   const int8_t quantized_output_min = xnn_qs8_quantize(output_min, output_scale, output_zero_point);
189*4bdc9457SAndroid Build Coastguard Worker   const int8_t quantized_output_max = xnn_qs8_quantize(output_max, output_scale, output_zero_point);
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
192*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
193*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
194*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
195*4bdc9457SAndroid Build Coastguard Worker 
196*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
197*4bdc9457SAndroid Build Coastguard Worker 
198*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
199*4bdc9457SAndroid Build Coastguard Worker 
200*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
201*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
202*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_create_multiply_nd_qs8(
203*4bdc9457SAndroid Build Coastguard Worker                           input1_zero_point, input1_scale, input2_zero_point, input2_scale, output_zero_point,
204*4bdc9457SAndroid Build Coastguard Worker                           output_scale, quantized_output_min, quantized_output_max, /*flags=*/0, &op));
205*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
206*4bdc9457SAndroid Build Coastguard Worker 
207*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
208*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_multiply_nd_qs8(
209*4bdc9457SAndroid Build Coastguard Worker                           op, input1_dims.size(), input1_dims.data(), input2_dims.size(), input2_dims.data(),
210*4bdc9457SAndroid Build Coastguard Worker                           input1.data(), input2.data(), operator_output.data(), nullptr /* thread pool */));
211*4bdc9457SAndroid Build Coastguard Worker 
212*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, nullptr /* thread pool */));
213*4bdc9457SAndroid Build Coastguard Worker 
214*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
215*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
216*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
217*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
218*4bdc9457SAndroid Build Coastguard Worker 
219*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
220*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
221*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
222*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
223*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
224*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
225*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
226*4bdc9457SAndroid Build Coastguard Worker 
227*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
228*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
229*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
230*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
231*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
232*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
233*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
234*4bdc9457SAndroid Build Coastguard Worker 
235*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
236*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
237*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
238*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(),
239*4bdc9457SAndroid Build Coastguard Worker                           output_dims.data(), nullptr, /*external_id=*/2,
240*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
241*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
242*4bdc9457SAndroid Build Coastguard Worker 
243*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
244*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
245*4bdc9457SAndroid Build Coastguard Worker     xnn_define_multiply2(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
246*4bdc9457SAndroid Build Coastguard Worker 
247*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
248*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
249*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
250*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
251*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 3> external = {
252*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
253*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
254*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
255*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
256*4bdc9457SAndroid Build Coastguard Worker 
257*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
258*4bdc9457SAndroid Build Coastguard Worker }
259*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Multiply2TestQU8,matches_operator_api)260*4bdc9457SAndroid Build Coastguard Worker TEST_F(Multiply2TestQU8, matches_operator_api)
261*4bdc9457SAndroid Build Coastguard Worker {
262*4bdc9457SAndroid Build Coastguard Worker   const int32_t input1_zero_point = u8dist(rng);
263*4bdc9457SAndroid Build Coastguard Worker   const float input1_scale = scale_dist(rng);
264*4bdc9457SAndroid Build Coastguard Worker   const int32_t input2_zero_point = u8dist(rng);
265*4bdc9457SAndroid Build Coastguard Worker   const float input2_scale = scale_dist(rng);
266*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_zero_point = u8dist(rng);
267*4bdc9457SAndroid Build Coastguard Worker   const float output_scale = scale_dist(rng);
268*4bdc9457SAndroid Build Coastguard Worker   const uint8_t quantized_output_min = xnn_qu8_quantize(output_min, output_scale, output_zero_point);
269*4bdc9457SAndroid Build Coastguard Worker   const uint8_t quantized_output_max = xnn_qu8_quantize(output_max, output_scale, output_zero_point);
270*4bdc9457SAndroid Build Coastguard Worker 
271*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
272*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
273*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
274*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
275*4bdc9457SAndroid Build Coastguard Worker 
276*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
277*4bdc9457SAndroid Build Coastguard Worker 
278*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
279*4bdc9457SAndroid Build Coastguard Worker 
280*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
281*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
282*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_create_multiply_nd_qu8(
283*4bdc9457SAndroid Build Coastguard Worker                           input1_zero_point, input1_scale, input2_zero_point, input2_scale, output_zero_point,
284*4bdc9457SAndroid Build Coastguard Worker                           output_scale, quantized_output_min, quantized_output_max, /*flags=*/0, &op));
285*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
286*4bdc9457SAndroid Build Coastguard Worker 
287*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
288*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_multiply_nd_qu8(
289*4bdc9457SAndroid Build Coastguard Worker                           op, input1_dims.size(), input1_dims.data(), input2_dims.size(), input2_dims.data(),
290*4bdc9457SAndroid Build Coastguard Worker                           input1.data(), input2.data(), operator_output.data(), nullptr /* thread pool */));
291*4bdc9457SAndroid Build Coastguard Worker 
292*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, nullptr /* thread pool */));
293*4bdc9457SAndroid Build Coastguard Worker 
294*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
295*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
296*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
297*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
300*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
301*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
302*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
303*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, input1_zero_point, input1_scale, input1_dims.size(), input1_dims.data(), nullptr,
304*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
305*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
306*4bdc9457SAndroid Build Coastguard Worker 
307*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
308*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
309*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
310*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
311*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, input2_zero_point, input2_scale, input2_dims.size(), input2_dims.data(), nullptr,
312*4bdc9457SAndroid Build Coastguard Worker       /*external_id=*/1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
313*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
314*4bdc9457SAndroid Build Coastguard Worker 
315*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
316*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
317*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_quantized_tensor_value(
318*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_quint8, output_zero_point, output_scale, output_dims.size(),
319*4bdc9457SAndroid Build Coastguard Worker                           output_dims.data(), nullptr, /*external_id=*/2,
320*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
321*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
322*4bdc9457SAndroid Build Coastguard Worker 
323*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
324*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
325*4bdc9457SAndroid Build Coastguard Worker     xnn_define_multiply2(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
326*4bdc9457SAndroid Build Coastguard Worker 
327*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
328*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
329*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
330*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
331*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 3> external = {
332*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
333*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
334*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
335*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
336*4bdc9457SAndroid Build Coastguard Worker 
337*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
338*4bdc9457SAndroid Build Coastguard Worker }
339*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Multiply2TestF32,matches_operator_api)340*4bdc9457SAndroid Build Coastguard Worker TEST_F(Multiply2TestF32, matches_operator_api)
341*4bdc9457SAndroid Build Coastguard Worker {
342*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
343*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
344*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), nanf(""));
345*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), nanf(""));
346*4bdc9457SAndroid Build Coastguard Worker 
347*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
348*4bdc9457SAndroid Build Coastguard Worker 
349*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op = nullptr;
350*4bdc9457SAndroid Build Coastguard Worker 
351*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
352*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_multiply_nd_f32(output_min, output_max, 0, &op));
353*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator);
354*4bdc9457SAndroid Build Coastguard Worker 
355*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
356*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_multiply_nd_f32(
357*4bdc9457SAndroid Build Coastguard Worker                           op, input1_dims.size(), input1_dims.data(), input2_dims.size(), input2_dims.data(),
358*4bdc9457SAndroid Build Coastguard Worker                           input1.data(), input2.data(), operator_output.data(), nullptr /* thread pool */));
359*4bdc9457SAndroid Build Coastguard Worker 
360*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op, nullptr /* thread pool */));
361*4bdc9457SAndroid Build Coastguard Worker 
362*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
363*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
364*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(3, /*flags=*/0, &subgraph));
365*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
366*4bdc9457SAndroid Build Coastguard Worker 
367*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id = XNN_INVALID_NODE_ID;
368*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
369*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
370*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr,
371*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
372*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
373*4bdc9457SAndroid Build Coastguard Worker 
374*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id = XNN_INVALID_NODE_ID;
375*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
376*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
377*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr,
378*4bdc9457SAndroid Build Coastguard Worker                           /*external_id=*/1, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
379*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
380*4bdc9457SAndroid Build Coastguard Worker 
381*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_NODE_ID;
382*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
383*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
384*4bdc9457SAndroid Build Coastguard Worker     xnn_define_tensor_value(
385*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, /*external_id=*/2,
386*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
387*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
388*4bdc9457SAndroid Build Coastguard Worker 
389*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
390*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
391*4bdc9457SAndroid Build Coastguard Worker     xnn_define_multiply2(subgraph, output_min, output_max, input1_id, input2_id, output_id, /*flags=*/0));
392*4bdc9457SAndroid Build Coastguard Worker 
393*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
394*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
395*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
396*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
397*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 3> external = {
398*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
399*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
400*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
401*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
402*4bdc9457SAndroid Build Coastguard Worker 
403*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
404*4bdc9457SAndroid Build Coastguard Worker }
405*4bdc9457SAndroid Build Coastguard Worker 
406