1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2020 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
7*4bdc9457SAndroid Build Coastguard Worker
8*4bdc9457SAndroid Build Coastguard Worker #include "subgraph-tester.h"
9*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
10*4bdc9457SAndroid Build Coastguard Worker
11*4bdc9457SAndroid Build Coastguard Worker namespace xnnpack {
12*4bdc9457SAndroid Build Coastguard Worker
// Graph with a single node: a strided 3x3 convolution (3 -> 32 channels) that
// reads the graph input directly. After Optimize() + RewriteForNchw() the test
// expects both the input (tensor 0) and the output (tensor 3) to still be NHWC.
TEST(SUBGRAPH_NCHW, single_conv) {
  const ConvolutionParams stem_conv{
    Padding{1, 1, 1, 1},
    Kernel{3, 3},
    Subsampling{2, 2},
    Dilation{1, 1},
    /*groups=*/ 1,
    /*group_input_channels=*/ 3,
    /*group_output_channels=*/ 32
  };

  SubgraphTester tester(4);
  tester
      .AddDynamicTensorF32({1, 256, 256, 3}, 0)                  // input
      .AddStaticTensorF32({32, 3, 3, 3}, TensorType::kDense, 1)  // filter
      .AddStaticTensorF32({32}, TensorType::kDense, 2)           // bias
      .AddDynamicTensorF32({1, 128, 128, 32}, 3)                 // output
      .AddConvolution2D(stem_conv, 0, 1, 2, 3)
      .Optimize()
      .RewriteForNchw();

  ASSERT_EQ(tester.GetLayout(0), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.GetLayout(3), xnn_layout_type_nhwc);
}
36*4bdc9457SAndroid Build Coastguard Worker
// A strided 3x3 convolution (3 -> 32 channels) followed by global average
// pooling. After Optimize() + RewriteForNchw() the test expects the graph
// input (0), the convolution output (3) and the pooled output (4) to all
// remain in NHWC layout.
TEST(SUBGRAPH_NCHW, single_conv_and_global_average_pooling) {
  auto tester = SubgraphTester(5);
  tester
      .AddDynamicTensorF32({1, 256, 256, 3}, 0)
      .AddStaticTensorF32({32, 3, 3, 3}, TensorType::kDense, 1)
      .AddStaticTensorF32({32}, TensorType::kDense, 2)
      .AddDynamicTensorF32({1, 128, 128, 32}, 3)
      // Pooled output, declared as {batch, channels} to match the pooled
      // outputs of the other tests in this file ({1, 4}, {1, 8}); it was
      // previously missing the batch dimension.
      .AddDynamicTensorF32({1, 32}, 4)
      .AddConvolution2D(
          ConvolutionParams{
            Padding{1, 1, 1, 1},
            Kernel{3, 3},
            Subsampling{2, 2},
            Dilation{1, 1},
            /*groups=*/ 1,
            /*group_input_channels=*/ 3,
            /*group_output_channels=*/ 32
          }, 0, 1, 2, 3)
      .AddGlobalAveragePooling(3, 4)
      .Optimize()
      .RewriteForNchw();

  ASSERT_EQ(tester.GetLayout(0), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.GetLayout(3), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.GetLayout(4), xnn_layout_type_nhwc);
}
63*4bdc9457SAndroid Build Coastguard Worker
// Strided 3x3 convolution (3 -> 8 channels), then a 1x1 convolution with a
// sparse filter (8 -> 4 channels), then global average pooling. The test
// expects the two intermediate activations to be rewritten to NCHW, while the
// graph input and the pooled output stay NHWC.
TEST(SUBGRAPH_NCHW, pixelwise_conv_sandwich) {
  // External tensor ids used by this graph.
  constexpr uint32_t kInput = 0;
  constexpr uint32_t kStemFilter = 1;
  constexpr uint32_t kStemBias = 2;
  constexpr uint32_t kStemOut = 3;
  constexpr uint32_t kPointwiseFilter = 4;
  constexpr uint32_t kPointwiseBias = 5;
  constexpr uint32_t kPointwiseOut = 6;
  constexpr uint32_t kPooled = 7;

  SubgraphTester tester(8);
  tester
      .AddDynamicTensorF32({1, 256, 256, 3}, kInput)
      .AddStaticTensorF32({8, 3, 3, 3}, TensorType::kDense, kStemFilter)
      .AddStaticTensorF32({8}, TensorType::kDense, kStemBias)
      .AddDynamicTensorF32({1, 128, 128, 8}, kStemOut)
      .AddStaticTensorF32({4, 1, 1, 8}, TensorType::kSparse, kPointwiseFilter)
      .AddStaticTensorF32({4}, TensorType::kDense, kPointwiseBias)
      .AddDynamicTensorF32({1, 128, 128, 4}, kPointwiseOut)
      .AddDynamicTensorF32({1, 4}, kPooled)
      .AddConvolution2D(
          ConvolutionParams{
            Padding{1, 1, 1, 1},
            Kernel{3, 3},
            Subsampling{2, 2},
            Dilation{1, 1},
            /*groups=*/ 1,
            /*group_input_channels=*/ 3,
            /*group_output_channels=*/ 8
          }, kInput, kStemFilter, kStemBias, kStemOut)
      .AddConvolution2D(
          ConvolutionParams{
            Padding{0, 0, 0, 0},
            Kernel{1, 1},
            Subsampling{1, 1},
            Dilation{1, 1},
            /*groups=*/ 1,
            /*group_input_channels=*/ 8,
            /*group_output_channels=*/ 4
          }, kStemOut, kPointwiseFilter, kPointwiseBias, kPointwiseOut)
      .AddGlobalAveragePooling(kPointwiseOut, kPooled)
      .Optimize()
      .RewriteForNchw();

  ASSERT_EQ(tester.GetLayout(kInput), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.GetLayout(kStemOut), xnn_layout_type_nchw);
  ASSERT_EQ(tester.GetLayout(kPointwiseOut), xnn_layout_type_nchw);
  ASSERT_EQ(tester.GetLayout(kPooled), xnn_layout_type_nhwc);
}
104*4bdc9457SAndroid Build Coastguard Worker
// Bottleneck block with a residual connection:
//   0 --conv 3x3/s2 (3->8)--> 3
//   3 --conv 1x1, sparse filter (8->4)--> 6
//   6 --depthwise 3x3 (4 channels)--> 9
//   9 --conv 1x1, sparse filter (4->8)--> 12
//   3 + 12 --> 13
//   13 --global average pooling--> 14
// After Optimize() + RewriteForNchw() every intermediate tensor (3, 6, 9, 12,
// 13) is expected to be NCHW, while the graph input (0) and the pooled output
// (14) stay NHWC.
TEST(SUBGRAPH_NCHW, bottleneck) {
  auto tester = SubgraphTester(15);
  tester
      .AddDynamicTensorF32({1, 256, 256, 3}, 0)
      .AddStaticTensorF32({8, 3, 3, 3}, TensorType::kDense, 1)
      .AddStaticTensorF32({8}, TensorType::kDense, 2)
      .AddDynamicTensorF32({1, 128, 128, 8}, 3)
      .AddStaticTensorF32({4, 1, 1, 8}, TensorType::kSparse, 4)
      .AddStaticTensorF32({4}, TensorType::kDense, 5)
      .AddDynamicTensorF32({1, 128, 128, 4}, 6)
      .AddStaticTensorF32({1, 3, 3, 4}, TensorType::kDense, 7)
      .AddStaticTensorF32({4}, TensorType::kDense, 8)
      .AddDynamicTensorF32({1, 128, 128, 4}, 9)
      .AddStaticTensorF32({8, 1, 1, 4}, TensorType::kSparse, 10)
      .AddStaticTensorF32({8}, TensorType::kDense, 11)
      .AddDynamicTensorF32({1, 128, 128, 8}, 12)
      // Residual sum of tensors 3 and 12; each external id is declared once.
      .AddDynamicTensorF32({1, 128, 128, 8}, 13)
      .AddDynamicTensorF32({1, 8}, 14)
      .AddConvolution2D(
          ConvolutionParams{
            Padding{1, 1, 1, 1},
            Kernel{3, 3},
            Subsampling{2, 2},
            Dilation{1, 1},
            /*groups=*/ 1,
            /*group_input_channels=*/ 3,
            /*group_output_channels=*/ 8
          }, 0, 1, 2, 3)
      .AddConvolution2D(
          ConvolutionParams{
            Padding{0, 0, 0, 0},
            Kernel{1, 1},
            Subsampling{1, 1},
            Dilation{1, 1},
            /*groups=*/ 1,
            /*group_input_channels=*/ 8,
            /*group_output_channels=*/ 4
          }, 3, 4, 5, 6)
      .AddDepthwiseConvolution2D(
          DepthwiseConvolutionParams{
            Padding{1, 1, 1, 1},
            Kernel{3, 3},
            Subsampling{1, 1},
            Dilation{1, 1},
            /*depth_multiplier=*/ 1,
            /*input_channels=*/ 4
          }, 6, 7, 8, 9)
      // Expansion back to 8 channels: filter 10 is {8, 1, 1, 4} (same
      // output-channels-first layout as filter 1), input 9 has 4 channels and
      // output 12 has 8, so the group maps 4 inputs to 8 outputs.
      .AddConvolution2D(
          ConvolutionParams{
            Padding{0, 0, 0, 0},
            Kernel{1, 1},
            Subsampling{1, 1},
            Dilation{1, 1},
            /*groups=*/ 1,
            /*group_input_channels=*/ 4,
            /*group_output_channels=*/ 8
          }, 9, 10, 11, 12)
      .AddAddition(3, 12, 13)
      .AddGlobalAveragePooling(13, 14)
      .Optimize()
      .RewriteForNchw();

  ASSERT_EQ(tester.GetLayout(0), xnn_layout_type_nhwc);
  ASSERT_EQ(tester.GetLayout(3), xnn_layout_type_nchw);
  ASSERT_EQ(tester.GetLayout(6), xnn_layout_type_nchw);
  ASSERT_EQ(tester.GetLayout(9), xnn_layout_type_nchw);
  ASSERT_EQ(tester.GetLayout(12), xnn_layout_type_nchw);
  ASSERT_EQ(tester.GetLayout(13), xnn_layout_type_nchw);
  ASSERT_EQ(tester.GetLayout(14), xnn_layout_type_nhwc);
}
176*4bdc9457SAndroid Build Coastguard Worker
177*4bdc9457SAndroid Build Coastguard Worker } // namespace xnnpack
178