1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include "runtime-tester.h"
10*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
11*4bdc9457SAndroid Build Coastguard Worker
12*4bdc9457SAndroid Build Coastguard Worker namespace xnnpack {
13*4bdc9457SAndroid Build Coastguard Worker
// Checks that an Addition followed by a Clamp collapses into one operator:
// the Addition node inherits the Clamp's [min, max] range and writes directly
// to the graph output, and the result matches the unfused graph exactly.
TEST(ADD_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kLhsId = 0;
  constexpr uint32_t kRhsId = 1;
  constexpr uint32_t kSumId = 2;
  constexpr uint32_t kResultId = 3;

  RuntimeTester tester(4);

  // Tensors: two external inputs, an internal sum, and the external output.
  tester.AddInputTensorF32({1, 2, 2, 3}, kLhsId)
      .AddInputTensorF32({1, 2, 2, 3}, kRhsId)
      .AddDynamicTensorF32({1, 2, 2, 3}, kSumId)
      .AddOutputTensorF32({1, 2, 2, 3}, kResultId);
  // Nodes: sum = lhs + rhs, then result = clamp(sum).
  tester.AddAddition(kLhsId, kRhsId, kSumId)
      .AddClamp(kClampMin, kClampMax, kSumId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  // After fusion only one operator remains; node 0 carries the clamp range.
  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  // The absorbed Clamp node stays in the node list but is marked invalid.
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
43*4bdc9457SAndroid Build Coastguard Worker
// Checks that 2D Average Pooling followed by a Clamp is fused: the pooling
// node takes over the Clamp's output range and graph output, and fused and
// unfused runs produce identical results.
TEST(AVERAGE_POOLING_2D_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kSourceId = 0;
  constexpr uint32_t kPooledId = 1;
  constexpr uint32_t kResultId = 2;

  RuntimeTester tester(3);

  tester.AddInputTensorF32({1, 10, 10, 3}, kSourceId)
      .AddDynamicTensorF32({1, 9, 9, 3}, kPooledId)
      .AddOutputTensorF32({1, 9, 9, 3}, kResultId);
  // NOTE(review): positional arguments presumably are padding (top, right,
  // bottom, left), pooling (height, width), stride (height, width), then
  // input/output ids. TODO confirm against RuntimeTester. A 2x2 window with
  // stride 1 and no padding is consistent with the 10x10 -> 9x9 shapes above.
  tester.AddAveragePooling2D(0, 0, 0, 0, 2, 2, 1, 1, kSourceId, kPooledId)
      .AddClamp(kClampMin, kClampMax, kPooledId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  // The Clamp node is kept in the node list but invalidated.
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
71*4bdc9457SAndroid Build Coastguard Worker
// Checks that two consecutive Clamps are fused into one. The first clamp is
// unbounded (-inf, +inf), so the surviving node must end up with the second
// clamp's [min, max] and write directly to the graph output.
TEST(CLAMP_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr float kUnbounded = std::numeric_limits<float>::infinity();
  constexpr uint32_t kSourceId = 0;
  constexpr uint32_t kPassThroughId = 1;
  constexpr uint32_t kResultId = 2;

  RuntimeTester tester(3);

  tester.AddInputTensorF32({1, 10, 10, 3}, kSourceId)
      .AddDynamicTensorF32({1, 10, 10, 3}, kPassThroughId)
      .AddOutputTensorF32({1, 10, 10, 3}, kResultId);
  tester.AddClamp(-kUnbounded, kUnbounded, kSourceId, kPassThroughId)
      .AddClamp(kClampMin, kClampMax, kPassThroughId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  // The second Clamp is absorbed and left behind as an invalid node.
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
103*4bdc9457SAndroid Build Coastguard Worker
// Checks that a 2D Convolution followed by a Clamp is fused: the convolution
// node gets the Clamp's output range and graph output, and results match the
// unfused graph.
TEST(CONVOLUTION_2D_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kImageId = 0;
  constexpr uint32_t kWeightsId = 1;
  constexpr uint32_t kBiasId = 2;
  constexpr uint32_t kConvOutId = 3;
  constexpr uint32_t kResultId = 4;

  RuntimeTester tester(5);

  // 256x256 input, padding 1 on every side, 3x3 kernel, stride 2:
  // (256 + 2 - 3) / 2 + 1 = 128 per spatial dimension.
  tester.AddInputTensorF32({1, 256, 256, 3}, kImageId)
      .AddStaticTensorF32({32, 3, 3, 3}, TensorType::kDense, kWeightsId)
      .AddStaticTensorF32({32}, TensorType::kDense, kBiasId)
      .AddDynamicTensorF32({1, 128, 128, 32}, kConvOutId)
      .AddOutputTensorF32({1, 128, 128, 32}, kResultId);

  const ConvolutionParams conv_params{
      Padding{1, 1, 1, 1},
      Kernel{3, 3},
      Subsampling{2, 2},
      Dilation{1, 1},
      /*groups=*/1,
      /*group_input_channels=*/3,
      /*group_output_channels=*/32,
  };
  tester.AddConvolution2D(conv_params, kImageId, kWeightsId, kBiasId, kConvOutId)
      .AddClamp(kClampMin, kClampMax, kConvOutId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
144*4bdc9457SAndroid Build Coastguard Worker
// Checks that a Divide followed by a Clamp is fused into a single Divide node
// that carries the Clamp's range and writes to the graph output, with results
// identical to the unfused graph.
TEST(DIVIDE_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kNumeratorId = 0;
  constexpr uint32_t kDenominatorId = 1;
  constexpr uint32_t kQuotientId = 2;
  constexpr uint32_t kResultId = 3;

  RuntimeTester tester(4);

  tester.AddInputTensorF32({1, 2, 2, 3}, kNumeratorId)
      .AddInputTensorF32({1, 2, 2, 3}, kDenominatorId)
      .AddDynamicTensorF32({1, 2, 2, 3}, kQuotientId)
      .AddOutputTensorF32({1, 2, 2, 3}, kResultId);
  tester.AddDivide(kNumeratorId, kDenominatorId, kQuotientId)
      .AddClamp(kClampMin, kClampMax, kQuotientId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
174*4bdc9457SAndroid Build Coastguard Worker
// Checks that a 2D Deconvolution followed by a Clamp is fused: the
// deconvolution node inherits the Clamp's output range and graph output, and
// fused and unfused runs produce identical results.
TEST(DECONVOLUTION_2D_THEN_CLAMP, fusion) {
  auto tester = RuntimeTester(5);
  float output_min = -0.5f;
  float output_max = 0.5f;
  uint32_t input_id = 0;
  uint32_t filter_id = 1;
  uint32_t bias_id = 2;
  uint32_t intermediate_id = 3;
  uint32_t output_id = 4;
  // 128x128 input, upsampling 2, 3x3 kernel, padding 1 on each side, no
  // adjustment: (128 - 1) * 2 + 3 - (1 + 1) = 255 per spatial dimension.
  tester
    .AddInputTensorF32({1, 128, 128, 3}, input_id)
    .AddStaticTensorF32({32, 3, 3, 3}, TensorType::kDense, filter_id)
    .AddStaticTensorF32({32}, TensorType::kDense, bias_id)
    .AddDynamicTensorF32({1, 255, 255, 32}, intermediate_id)
    .AddOutputTensorF32({1, 255, 255, 32}, output_id)
    .AddDeconvolution2D(
      DeconvolutionParams{
        Padding{1, 1, 1, 1},
        Adjustment{0, 0},
        Kernel{3, 3},
        Upsampling{2, 2},
        Dilation{1, 1},
        /*groups=*/ 1,
        /*group_input_channels=*/ 3,
        /*group_output_channels=*/ 32
      }, input_id, filter_id, bias_id, intermediate_id)
    .AddClamp(output_min, output_max, intermediate_id, output_id);

  std::vector<float> unoptimized_output = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  std::vector<float> optimized_output = tester.RunWithFusion<float>();

  // Only the deconvolution remains as an operator; the Clamp node is kept in
  // the node list but marked invalid.
  ASSERT_EQ(tester.NumOperators(), 1);
  ASSERT_EQ(tester.Node(0)->activation.output_min, output_min);
  ASSERT_EQ(tester.Node(0)->activation.output_max, output_max);
  ASSERT_EQ(tester.Node(0)->outputs[0], output_id);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(unoptimized_output, optimized_output);
}
216*4bdc9457SAndroid Build Coastguard Worker
// Checks that a 2D Depthwise Convolution followed by a Clamp is fused: the
// depthwise node takes over the Clamp's range and graph output, and results
// match the unfused graph.
TEST(DEPTHWISE_CONVOLUTION_2D_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kImageId = 0;
  constexpr uint32_t kWeightsId = 1;
  constexpr uint32_t kBiasId = 2;
  constexpr uint32_t kConvOutId = 3;
  constexpr uint32_t kResultId = 4;

  RuntimeTester tester(5);

  // 4 channels, depth multiplier 1. The 3x3 kernel with padding 1 and stride 1
  // keeps the 128x128 spatial size.
  tester.AddInputTensorF32({1, 128, 128, 4}, kImageId)
      .AddStaticTensorF32({1, 3, 3, 4}, TensorType::kDense, kWeightsId)
      .AddStaticTensorF32({4}, TensorType::kDense, kBiasId)
      .AddDynamicTensorF32({1, 128, 128, 4}, kConvOutId)
      .AddOutputTensorF32({1, 128, 128, 4}, kResultId);

  const DepthwiseConvolutionParams dw_params{
      Padding{1, 1, 1, 1},
      Kernel{3, 3},
      Subsampling{1, 1},
      Dilation{1, 1},
      /*depth_multiplier=*/1,
      /*input_channels=*/4
  };
  tester.AddDepthwiseConvolution2D(dw_params, kImageId, kWeightsId, kBiasId, kConvOutId)
      .AddClamp(kClampMin, kClampMax, kConvOutId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
256*4bdc9457SAndroid Build Coastguard Worker
// Checks that a Fully Connected node followed by a Clamp is fused: the fully
// connected node inherits the Clamp's range and graph output, with results
// identical to the unfused graph.
TEST(FULLY_CONNECTED_2D_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kActivationsId = 0;
  constexpr uint32_t kWeightsId = 1;
  constexpr uint32_t kBiasId = 2;
  constexpr uint32_t kDenseOutId = 3;
  constexpr uint32_t kResultId = 4;

  RuntimeTester tester(5);

  // 5 rows of 3 features mapped to 7 outputs (weights shaped 7x3, bias 7).
  tester.AddInputTensorF32({5, 3}, kActivationsId)
      .AddStaticTensorF32({7, 3}, TensorType::kDense, kWeightsId)
      .AddStaticTensorF32({7}, TensorType::kDense, kBiasId)
      .AddDynamicTensorF32({5, 7}, kDenseOutId)
      .AddOutputTensorF32({5, 7}, kResultId);
  tester.AddFullyConnected(kActivationsId, kWeightsId, kBiasId, kDenseOutId)
      .AddClamp(kClampMin, kClampMax, kDenseOutId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
288*4bdc9457SAndroid Build Coastguard Worker
// Checks that a Multiply followed by a Clamp is fused into a single Multiply
// node carrying the Clamp's range and writing to the graph output, with
// results identical to the unfused graph.
TEST(MULTIPLY_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kLhsId = 0;
  constexpr uint32_t kRhsId = 1;
  constexpr uint32_t kProductId = 2;
  constexpr uint32_t kResultId = 3;

  RuntimeTester tester(4);

  tester.AddInputTensorF32({1, 2, 2, 3}, kLhsId)
      .AddInputTensorF32({1, 2, 2, 3}, kRhsId)
      .AddDynamicTensorF32({1, 2, 2, 3}, kProductId)
      .AddOutputTensorF32({1, 2, 2, 3}, kResultId);
  tester.AddMultiply(kLhsId, kRhsId, kProductId)
      .AddClamp(kClampMin, kClampMax, kProductId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
318*4bdc9457SAndroid Build Coastguard Worker
// Checks that 2D Max Pooling followed by a Clamp is fused: the pooling node
// takes over the Clamp's output range and graph output, and fused and unfused
// runs agree exactly.
TEST(MAX_POOLING_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kSourceId = 0;
  constexpr uint32_t kPooledId = 1;
  constexpr uint32_t kResultId = 2;

  RuntimeTester tester(3);

  tester.AddInputTensorF32({1, 10, 10, 3}, kSourceId)
      .AddDynamicTensorF32({1, 9, 9, 3}, kPooledId)
      .AddOutputTensorF32({1, 9, 9, 3}, kResultId);
  // NOTE(review): positional arguments presumably are padding (top, right,
  // bottom, left), pooling (height, width), stride (height, width), dilation
  // (height, width), then input/output ids. TODO confirm against
  // RuntimeTester. A 2x2 window with stride 1 is consistent with 10x10 -> 9x9.
  tester.AddMaxPooling2D(0, 0, 0, 0, 2, 2, 1, 1, 1, 1, kSourceId, kPooledId)
      .AddClamp(kClampMin, kClampMax, kPooledId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
346*4bdc9457SAndroid Build Coastguard Worker
// Checks that a Subtract followed by a Clamp is fused into a single Subtract
// node carrying the Clamp's range and writing to the graph output, with
// results identical to the unfused graph.
TEST(SUBTRACT_THEN_CLAMP, fusion) {
  constexpr float kClampMin = -0.5f;
  constexpr float kClampMax = 0.5f;
  constexpr uint32_t kMinuendId = 0;
  constexpr uint32_t kSubtrahendId = 1;
  constexpr uint32_t kDifferenceId = 2;
  constexpr uint32_t kResultId = 3;

  RuntimeTester tester(4);

  tester.AddInputTensorF32({1, 2, 2, 3}, kMinuendId)
      .AddInputTensorF32({1, 2, 2, 3}, kSubtrahendId)
      .AddDynamicTensorF32({1, 2, 2, 3}, kDifferenceId)
      .AddOutputTensorF32({1, 2, 2, 3}, kResultId);
  tester.AddSubtract(kMinuendId, kSubtrahendId, kDifferenceId)
      .AddClamp(kClampMin, kClampMax, kDifferenceId, kResultId);

  const std::vector<float> reference = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  const std::vector<float> fused_result = tester.RunWithFusion<float>();

  ASSERT_EQ(tester.NumOperators(), 1);
  const auto& fused_node = *tester.Node(0);
  ASSERT_EQ(fused_node.activation.output_min, kClampMin);
  ASSERT_EQ(fused_node.activation.output_max, kClampMax);
  ASSERT_EQ(fused_node.outputs[0], kResultId);
  ASSERT_EQ(tester.Node(1)->compute_type, xnn_compute_type_invalid);

  ASSERT_EQ(reference, fused_result);
}
376*4bdc9457SAndroid Build Coastguard Worker
// Checks that a Constant Pad (zero padding) feeding a 2D Convolution is fused.
// The pad node is invalidated, its spatial pre/post paddings become the
// convolution's implicit input paddings, and results match the unfused graph.
TEST(CONSTANT_PAD_THEN_CONVOLUTION, fusion) {
  auto tester = RuntimeTester(5);
  uint32_t input_id = 0;
  uint32_t intermediate_id = 1;
  uint32_t filter_id = 2;
  uint32_t bias_id = 3;
  uint32_t output_id = 4;
  // NHWC paddings. Only H (2 before, 6 after) and W (4 before, 8 after) are
  // padded; N and C are left unpadded.
  size_t pre_paddings[4] = {0, 2, 4, 0};
  size_t post_paddings[4] = {0, 6, 8, 0};
  float padding_value = 0.0f;

  // Padded intermediate: H = 2 + 254 + 6 = 262, W = 4 + 254 + 8 = 266.
  // The convolution has a 3x3 kernel, stride 2, dilation 1 and no implicit
  // padding, so its output is:
  //   H = (262 - 3) / 2 + 1 = 130
  //   W = (266 - 3) / 2 + 1 = 132
  // The output tensor is declared with exactly that shape.
  tester
    .AddInputTensorF32({1, 254, 254, 3}, input_id)
    .AddDynamicTensorF32({1, 262, 266, 3}, intermediate_id)
    .AddStaticTensorF32({32, 3, 3, 3}, TensorType::kDense, filter_id)
    .AddStaticTensorF32({32}, TensorType::kDense, bias_id)
    .AddOutputTensorF32({1, 130, 132, 32}, output_id)
    .AddConstantPad(pre_paddings, post_paddings, padding_value, input_id, intermediate_id)
    .AddConvolution2D(
      ConvolutionParams{
        Padding{0, 0, 0, 0},
        Kernel{3, 3},
        Subsampling{2, 2},
        Dilation{1, 1},
        /*groups=*/ 1,
        /*group_input_channels=*/ 3,
        /*group_output_channels=*/ 32,
      }, intermediate_id, filter_id, bias_id, output_id);

  std::vector<float> unoptimized_output = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  std::vector<float> optimized_output = tester.RunWithFusion<float>();

  // Node 0 (the pad) is invalidated. Node 1 (the convolution) absorbs the
  // paddings: top/bottom from the H axis, left/right from the W axis.
  ASSERT_EQ(tester.NumOperators(), 1);
  ASSERT_EQ(tester.Node(0)->compute_type, xnn_compute_type_invalid);
  ASSERT_EQ(tester.Node(1)->params.convolution_2d.input_padding_top, 2);
  ASSERT_EQ(tester.Node(1)->params.convolution_2d.input_padding_left, 4);
  ASSERT_EQ(tester.Node(1)->params.convolution_2d.input_padding_right, 8);
  ASSERT_EQ(tester.Node(1)->params.convolution_2d.input_padding_bottom, 6);
  ASSERT_EQ(tester.Node(1)->outputs[0], output_id);

  ASSERT_EQ(unoptimized_output, optimized_output);
}
421*4bdc9457SAndroid Build Coastguard Worker
TEST(CONSTANT_PAD_THEN_CONVOLUTION, not_fused_due_to_non_zero_padding_in_n_dimension) {
  // A constant pad that also pads the batch (N) dimension must not be fused
  // into the convolution: both operators are expected to remain, and the
  // fused and unfused runs must produce identical results.
  auto tester = RuntimeTester(5);
  uint32_t input_id = 0;
  uint32_t intermediate_id = 1;
  uint32_t filter_id = 2;
  uint32_t bias_id = 3;
  uint32_t output_id = 4;
  // Non-zero pre-padding in the N dimension (batch 1 -> 1 + 1 + 0 = 2).
  size_t pre_paddings[4] = {1, 2, 4, 0};
  size_t post_paddings[4] = {0, 6, 8, 0};
  float padding_value = 0.0f;

  // Padded input: {1 + 1, 254 + 2 + 6, 254 + 4 + 8, 3} = {2, 262, 266, 3}.
  // Unpadded 3x3 stride-2 convolution output:
  //   height = (262 - 3) / 2 + 1 = 130, width = (266 - 3) / 2 + 1 = 132.
  tester
      .AddInputTensorF32({1, 254, 254, 3}, input_id)
      .AddDynamicTensorF32({2, 262, 266, 3}, intermediate_id)
      .AddStaticTensorF32({32, 3, 3, 3}, TensorType::kDense, filter_id)
      .AddStaticTensorF32({32}, TensorType::kDense, bias_id)
      .AddOutputTensorF32({2, 130, 132, 32}, output_id)
      .AddConstantPad(pre_paddings, post_paddings, padding_value, input_id, intermediate_id)
      .AddConvolution2D(
          ConvolutionParams{
              Padding{0, 0, 0, 0},
              Kernel{3, 3},
              Subsampling{2, 2},
              Dilation{1, 1},
              /*groups=*/ 1,
              /*group_input_channels=*/ 3,
              /*group_output_channels=*/ 32,
          }, intermediate_id, filter_id, bias_id, output_id);

  // Reference run with fusion disabled.
  std::vector<float> unoptimized_output = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  // With fusion enabled the pad must survive, and results must not change.
  std::vector<float> optimized_output = tester.RunWithFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);
  ASSERT_EQ(unoptimized_output, optimized_output);
}
455*4bdc9457SAndroid Build Coastguard Worker
TEST(CONSTANT_PAD_THEN_CONVOLUTION, not_fused_due_to_padding_value_not_zero) {
  // A constant pad with a non-zero padding value cannot be expressed as the
  // convolution's implicit (zero) input padding, so it must not be fused.
  auto tester = RuntimeTester(5);
  uint32_t input_id = 0;
  uint32_t intermediate_id = 1;
  uint32_t filter_id = 2;
  uint32_t bias_id = 3;
  uint32_t output_id = 4;
  size_t pre_paddings[4] = {0, 2, 4, 0};
  size_t post_paddings[4] = {0, 6, 8, 0};
  float padding_value = 1.0f;

  // The batch dimension is not padded here, so it stays 1.
  // Padded input: {1, 254 + 2 + 6, 254 + 4 + 8, 3} = {1, 262, 266, 3}.
  // Unpadded 3x3 stride-2 convolution output:
  //   height = (262 - 3) / 2 + 1 = 130, width = (266 - 3) / 2 + 1 = 132.
  tester
      .AddInputTensorF32({1, 254, 254, 3}, input_id)
      .AddDynamicTensorF32({1, 262, 266, 3}, intermediate_id)
      .AddStaticTensorF32({32, 3, 3, 3}, TensorType::kDense, filter_id)
      .AddStaticTensorF32({32}, TensorType::kDense, bias_id)
      .AddOutputTensorF32({1, 130, 132, 32}, output_id)
      .AddConstantPad(pre_paddings, post_paddings, padding_value, input_id, intermediate_id)
      .AddConvolution2D(
          ConvolutionParams{
              Padding{0, 0, 0, 0},
              Kernel{3, 3},
              Subsampling{2, 2},
              Dilation{1, 1},
              /*groups=*/ 1,
              /*group_input_channels=*/ 3,
              /*group_output_channels=*/ 32,
          }, intermediate_id, filter_id, bias_id, output_id);

  // Reference run with fusion disabled.
  std::vector<float> unoptimized_output = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  // With fusion enabled the pad must survive, and results must not change.
  std::vector<float> optimized_output = tester.RunWithFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);
  ASSERT_EQ(unoptimized_output, optimized_output);
}
488*4bdc9457SAndroid Build Coastguard Worker
TEST(CONSTANT_PAD_THEN_DEPTHWISE_CONVOLUTION, fusion) {
  // A zero-valued constant pad on H/W only should be folded into the depthwise
  // convolution's input padding, leaving a single operator with the same output.
  auto tester = RuntimeTester(5);
  uint32_t input_id = 0;
  uint32_t intermediate_id = 1;
  uint32_t filter_id = 2;
  uint32_t bias_id = 3;
  uint32_t output_id = 4;
  size_t pre_paddings[4] = {0, 2, 4, 0};
  size_t post_paddings[4] = {0, 6, 8, 0};
  float padding_value = 0.0f;

  // Padded input: {1, 128 + 2 + 6, 128 + 4 + 8, 4} = {1, 136, 140, 4}.
  // Unpadded 3x3 stride-1 convolution output:
  //   height = 136 - 3 + 1 = 134, width = 140 - 3 + 1 = 138.
  tester
      .AddInputTensorF32({1, 128, 128, 4}, input_id)
      .AddDynamicTensorF32({1, 136, 140, 4}, intermediate_id)
      .AddStaticTensorF32({1, 3, 3, 4}, TensorType::kDense, filter_id)
      .AddStaticTensorF32({4}, TensorType::kDense, bias_id)
      .AddOutputTensorF32({1, 134, 138, 4}, output_id)
      .AddConstantPad(pre_paddings, post_paddings, padding_value, input_id, intermediate_id)
      .AddDepthwiseConvolution2D(
          DepthwiseConvolutionParams{
              Padding{0, 0, 0, 0},
              Kernel{3, 3},
              Subsampling{1, 1},
              Dilation{1, 1},
              /*depth_multiplier=*/ 1,
              /*input_channels=*/ 4
          }, intermediate_id, filter_id, bias_id, output_id);

  // Reference run with fusion disabled.
  std::vector<float> unoptimized_output = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  std::vector<float> optimized_output = tester.RunWithFusion<float>();

  // The pad node is disabled and its H/W paddings are moved onto the
  // depthwise convolution node, which now writes the graph output directly.
  ASSERT_EQ(tester.NumOperators(), 1);
  ASSERT_EQ(tester.Node(0)->compute_type, xnn_compute_type_invalid);
  ASSERT_EQ(tester.Node(1)->params.depthwise_convolution_2d.input_padding_top, 2);
  ASSERT_EQ(tester.Node(1)->params.depthwise_convolution_2d.input_padding_left, 4);
  ASSERT_EQ(tester.Node(1)->params.depthwise_convolution_2d.input_padding_right, 8);
  ASSERT_EQ(tester.Node(1)->params.depthwise_convolution_2d.input_padding_bottom, 6);
  ASSERT_EQ(tester.Node(1)->outputs[0], output_id);
  ASSERT_EQ(unoptimized_output, optimized_output);
}
530*4bdc9457SAndroid Build Coastguard Worker
TEST(CONSTANT_PAD_THEN_DEPTHWISE_CONVOLUTION, not_fused_due_to_non_zero_padding_in_n_dimension) {
  // A constant pad that also pads the batch (N) dimension must not be fused
  // into the depthwise convolution: both operators are expected to remain.
  auto tester = RuntimeTester(5);
  uint32_t input_id = 0;
  uint32_t intermediate_id = 1;
  uint32_t filter_id = 2;
  uint32_t bias_id = 3;
  uint32_t output_id = 4;
  // Non-zero pre-padding in the N dimension (batch 1 -> 1 + 1 + 0 = 2).
  size_t pre_paddings[4] = {1, 2, 4, 0};
  size_t post_paddings[4] = {0, 6, 8, 0};
  float padding_value = 0.0f;

  // Padded input: {1 + 1, 128 + 2 + 6, 128 + 4 + 8, 4} = {2, 136, 140, 4}.
  // Unpadded 3x3 stride-1 convolution output:
  //   height = 136 - 3 + 1 = 134, width = 140 - 3 + 1 = 138.
  tester
      .AddInputTensorF32({1, 128, 128, 4}, input_id)
      .AddDynamicTensorF32({2, 136, 140, 4}, intermediate_id)
      .AddStaticTensorF32({1, 3, 3, 4}, TensorType::kDense, filter_id)
      .AddStaticTensorF32({4}, TensorType::kDense, bias_id)
      .AddOutputTensorF32({2, 134, 138, 4}, output_id)
      .AddConstantPad(pre_paddings, post_paddings, padding_value, input_id, intermediate_id)
      .AddDepthwiseConvolution2D(
          DepthwiseConvolutionParams{
              Padding{0, 0, 0, 0},
              Kernel{3, 3},
              Subsampling{1, 1},
              Dilation{1, 1},
              /*depth_multiplier=*/ 1,
              /*input_channels=*/ 4
          }, intermediate_id, filter_id, bias_id, output_id);

  // Reference run with fusion disabled.
  std::vector<float> unoptimized_output = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  // With fusion enabled the pad must survive, and results must not change.
  std::vector<float> optimized_output = tester.RunWithFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);
  ASSERT_EQ(unoptimized_output, optimized_output);
}
565*4bdc9457SAndroid Build Coastguard Worker
TEST(CONSTANT_PAD_THEN_DEPTHWISE_CONVOLUTION, not_fused_due_to_padding_value_not_zero) {
  // A constant pad with a non-zero padding value cannot be expressed as the
  // depthwise convolution's implicit (zero) input padding, so it must not be
  // fused.
  auto tester = RuntimeTester(5);
  uint32_t input_id = 0;
  uint32_t intermediate_id = 1;
  uint32_t filter_id = 2;
  uint32_t bias_id = 3;
  uint32_t output_id = 4;
  size_t pre_paddings[4] = {0, 2, 4, 0};
  size_t post_paddings[4] = {0, 6, 8, 0};
  float padding_value = 1.0f;

  // Padded input: {1, 128 + 2 + 6, 128 + 4 + 8, 4} = {1, 136, 140, 4}.
  // Unpadded 3x3 stride-1 convolution output:
  //   height = 136 - 3 + 1 = 134, width = 140 - 3 + 1 = 138.
  tester
      .AddInputTensorF32({1, 128, 128, 4}, input_id)
      .AddDynamicTensorF32({1, 136, 140, 4}, intermediate_id)
      .AddStaticTensorF32({1, 3, 3, 4}, TensorType::kDense, filter_id)
      .AddStaticTensorF32({4}, TensorType::kDense, bias_id)
      .AddOutputTensorF32({1, 134, 138, 4}, output_id)
      .AddConstantPad(pre_paddings, post_paddings, padding_value, input_id, intermediate_id)
      .AddDepthwiseConvolution2D(
          DepthwiseConvolutionParams{
              Padding{0, 0, 0, 0},
              Kernel{3, 3},
              Subsampling{1, 1},
              Dilation{1, 1},
              /*depth_multiplier=*/ 1,
              /*input_channels=*/ 4
          }, intermediate_id, filter_id, bias_id, output_id);

  // Reference run with fusion disabled.
  std::vector<float> unoptimized_output = tester.RunWithoutFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);

  // With fusion enabled the pad must survive, and results must not change.
  std::vector<float> optimized_output = tester.RunWithFusion<float>();
  ASSERT_EQ(tester.NumOperators(), 2);
  ASSERT_EQ(unoptimized_output, optimized_output);
}
599*4bdc9457SAndroid Build Coastguard Worker
600*4bdc9457SAndroid Build Coastguard Worker } // namespace xnnpack
601