// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

#include <xnnpack/pack.h>
#include <xnnpack/aligned-allocator.h>

#include <gtest/gtest.h>
#include <fp16.h>
14*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, GHW kernel layout (each channel's taps are contiguous
// in k), with primary_tile == h * w and c == cr, so a single channel block is
// produced and no padding is needed.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_GHW_W, primary_tile_eq_kernel_size) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<uint8_t> k(c * h * w);  // k = [2, 3, 4, 5, 6, 7]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_ghw_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 48387);
  std::vector<uint8_t> expected = {
    // bias first
    // 48387 + 0 - (2 + 3 + 4) * 127 = 47,244 = 0xB88C
    0x8C, 0xB8, 0, 0,
    // 48387 + 1 - (5 + 6 + 7) * 127 = 46,102 = 0xB416
    0x16, 0xB4, 0, 0,
    // then weights, channels first
    2, 5,
    3, 6,
    4, 7,
  };
  ASSERT_EQ(expected, packed_weights);
}
60*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, GHW kernel layout, with primary_tile == h * w and
// c (5) > cr (2). The output is three cr-sized channel blocks; the last block
// holds only one real channel, so its second bias slot and every second
// weight are expected to be zero.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_GHW_W, primary_tile_eq_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<uint8_t> k(c * h * w);  // k = [
                                      //   5, 6, 7,
                                      //   8, 9, 10,
                                      //   11, 12, 13,
                                      //   14, 15, 16,
                                      //   17, 18, 19]
  std::iota(k.begin(), k.end(), b.size());
  // Per (rounded-up) channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_ghw_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 48387);
  std::vector<uint8_t> expected = {
    // cr blocks
    // bias first (cr == 2 of them)
    // 48387 + 0 - (5 + 6 + 7) * 127 = 46,101 = 0xB415
    0x15, 0xB4, 0, 0,
    // 48387 + 1 - (8 + 9 + 10) * 127 = 44,959 = 0xAF9F
    0x9F, 0xAF, 0, 0,
    // then weights, channels first
    5, 8, 6, 9, 7, 10,
    // bias again
    // 48387 + 2 - (11 + 12 + 13) * 127 = 43,817 = 0xAB29
    0x29, 0xAB, 0, 0,
    // 48387 + 3 - (14 + 15 + 16) * 127 = 42,675 = 0xA6B3
    0xB3, 0xA6, 0, 0,
    // then weights, channels first
    11, 14, 12, 15, 13, 16,
    // bias again
    // 48387 + 4 - (17 + 18 + 19) * 127 = 41,533 = 0xA23D
    0x3D, 0xA2, 0, 0,
    // no sixth channel: zero bias slot
    0, 0, 0, 0,
    // then weights, channels first (second lane is zero padding)
    17, 0, 18, 0, 19, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
122*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, GHW kernel layout, with primary_tile (9) larger than
// the 2x2 kernel and c == cr. The four real taps are expected in column-first
// order, followed by zero padding for the (9 - 4) unused taps of each channel.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_GHW_W, primary_tile_gt_kernel_size) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<uint8_t> k(c * h * w);  // k = [
                                      //   2, 3,
                                      //   4, 5,
                                      //   6, 7,
                                      //   8, 9]
  std::iota(k.begin(), k.end(), b.size());
  // Per channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_ghw_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 64516);
  std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 64516 + 0 - (2 + 3 + 4 + 5) * 127 = 62,738 = 0xF512
    0x12, 0xF5, 0, 0,
    // 64516 + 1 - (6 + 7 + 8 + 9) * 127 = 60,707 = 0xED23
    0x23, 0xED, 0, 0,
    // then weights, channels first
    2, 6,
    // go down the columns first
    4, 8, 3, 7, 5, 9,
    // followed by 10 zeros to make up the difference with primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
173*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, GHW kernel layout, with primary_tile (9) larger than
// the 2x2 kernel and c (5) > cr (2). Three channel blocks are expected; each
// has its taps in column-first order followed by zero padding up to
// primary_tile, and the last block also has a zeroed second channel lane.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_GHW_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<uint8_t> k(c * h * w);  // k = [
                                      //   5, 6,
                                      //   7, 8,
                                      //   9, 10,
                                      //   11, 12,
                                      //   13, 14,
                                      //   15, 16,
                                      //   17, 18,
                                      //   19, 20,
                                      //   21, 22,
                                      //   23, 24]
  std::iota(k.begin(), k.end(), b.size());
  // Per (rounded-up) channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_ghw_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 64516);
  std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 64516 + 0 - (5 + 6 + 7 + 8) * 127 = 61,214 = 0xEF1E
    0x1E, 0xEF, 0, 0,
    // 64516 + 1 - (9 + 10 + 11 + 12) * 127 = 59,183 = 0xE72F
    0x2F, 0xE7, 0, 0,
    // then weights, channels first
    5, 9,
    // go down the columns first
    7, 11,
    6, 10,
    8, 12,
    // followed by 10 zeros to make up the difference with primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias first (cr == 2 of them)
    // 64516 + 2 - (13 + 14 + 15 + 16) * 127 = 57,152 = 0xDF40
    0x40, 0xDF, 0, 0,
    // 64516 + 3 - (17 + 18 + 19 + 20) * 127 = 55,121 = 0xD751
    0x51, 0xD7, 0, 0,
    // then weights, channels first
    13, 17, 15, 19, 14, 18, 16, 20,
    // zero padding up to primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias
    // 64516 + 4 - (21 + 22 + 23 + 24) * 127 = 53,090 = 0xCF62
    0x62, 0xCF, 0, 0,
    // no sixth channel: zero bias slot
    0, 0, 0, 0,
    // weights (second lane is zero padding)
    21, 0, 23, 0, 22, 0, 24, 0,
    // zero padding up to primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
247*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, HWG kernel layout (channels are the fastest-varying
// index of k, so channel 0 owns taps 2, 4, 6), with primary_tile == h * w and
// c == cr. A single channel block is produced and no padding is needed.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_HWG_W, primary_tile_eq_kernel_size) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<uint8_t> k(c * h * w);  // k = [2, 3, 4, 5, 6, 7]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_hwg_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 48387);
  std::vector<uint8_t> expected = {
    // bias first
    // 48387 + 0 - (2 + 4 + 6) * 127 = 46,863 = 0xB70F
    0x0F, 0xB7, 0, 0,
    // 48387 + 1 - (3 + 5 + 7) * 127 = 46,483 = 0xB593
    0x93, 0xB5, 0, 0,
    // then weights, channels first
    2, 3,
    4, 5,
    6, 7,
  };
  ASSERT_EQ(expected, packed_weights);
}
293*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, HWG kernel layout (channels are the fastest-varying
// index of k), with primary_tile == h * w and c (5) > cr (2). Three channel
// blocks are expected; the last block holds only one real channel, so its
// second bias slot and every second weight are expected to be zero.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_HWG_W, primary_tile_eq_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<uint8_t> k(c * h * w);  // k = [
                                      //   5, 6, 7, 8, 9,
                                      //   10, 11, 12, 13, 14,
                                      //   15, 16, 17, 18, 19]
  std::iota(k.begin(), k.end(), b.size());
  // Per (rounded-up) channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_hwg_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 48387);
  std::vector<uint8_t> expected = {
    // cr blocks
    // bias first (cr == 2 of them)
    // 48387 + 0 - (5 + 10 + 15) * 127 = 44,577 = 0xAE21
    0x21, 0xAE, 0, 0,
    // 48387 + 1 - (6 + 11 + 16) * 127 = 44,197 = 0xACA5
    0xA5, 0xAC, 0, 0,
    // then weights, channels first
    5, 6, 10, 11, 15, 16,
    // bias again
    // 48387 + 2 - (7 + 12 + 17) * 127 = 43,817 = 0xAB29
    0x29, 0xAB, 0, 0,
    // 48387 + 3 - (8 + 13 + 18) * 127 = 43,437 = 0xA9AD
    0xAD, 0xA9, 0, 0,
    // then weights, channels first
    7, 8, 12, 13, 17, 18,
    // bias again
    // 48387 + 4 - (9 + 14 + 19) * 127 = 43,057 = 0xA831
    0x31, 0xA8, 0, 0,
    // no sixth channel: zero bias slot
    0, 0, 0, 0,
    // then weights, channels first (second lane is zero padding)
    9, 0, 14, 0, 19, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
353*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, HWG kernel layout (channels are the fastest-varying
// index of k), with primary_tile (9) larger than the 2x2 kernel and c == cr.
// The four real taps are expected in column-first order, followed by zero
// padding for the (9 - 4) unused taps of each channel.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_HWG_W, primary_tile_gt_kernel_size) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<uint8_t> k(c * h * w);  // k = [
                                      //   2, 3,
                                      //   4, 5,
                                      //   6, 7,
                                      //   8, 9]
  std::iota(k.begin(), k.end(), b.size());
  // Per channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_hwg_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 64516);
  std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 64516 + 0 - (2 + 4 + 6 + 8) * 127 = 61,976 = 0xF218
    0x18, 0xF2, 0, 0,
    // 64516 + 1 - (3 + 5 + 7 + 9) * 127 = 61,469 = 0xF01D
    0x1D, 0xF0, 0, 0,
    // then weights, channels first
    2, 3,
    // go down the columns first
    6, 7, 4, 5, 8, 9,
    // followed by 10 zeros to make up the difference with primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
404*4bdc9457SAndroid Build Coastguard Worker
// QU8 depthwise packing, HWG kernel layout (channels are the fastest-varying
// index of k), with primary_tile (9) larger than the 2x2 kernel and
// c (5) > cr (2). Three channel blocks are expected; each has its taps in
// column-first order followed by zero padding up to primary_tile, and the
// last block also has a zeroed second channel lane.
//
// The expected int32 biases are written low byte first, so the byte-wise
// comparison presumes a little-endian host.
TEST(PACK_QU8_DWCONV_HWG_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<uint8_t> k(c * h * w);  // k = [
                                      //   5, 6, 7, 8, 9,
                                      //   10, 11, 12, 13, 14,
                                      //   15, 16, 17, 18, 19,
                                      //   20, 21, 22, 23, 24]
  std::iota(k.begin(), k.end(), b.size());
  // Per (rounded-up) channel: one int32 bias plus primary_tile uint8 weights.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qu8_packing_params params = {
    .input_zero_point = 127,
    .kernel_zero_point = 127,
  };
  xnn_pack_qu8_dwconv_hwg_w(
      primary_tile,
      h,
      w,
      c,
      cr,
      k.data(),
      b.data(),
      packed_weights.data(),
      0,
      &params);

  const int32_t bias_offset = h * w * params.input_zero_point * params.kernel_zero_point;
  ASSERT_EQ(bias_offset, 64516);
  std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 64516 + 0 - (5 + 10 + 15 + 20) * 127 = 58,166 = 0xE336
    0x36, 0xE3, 0, 0,
    // 64516 + 1 - (6 + 11 + 16 + 21) * 127 = 57,659 = 0xE13B
    0x3B, 0xE1, 0, 0,
    // then weights, channels first
    5, 6,
    // go down the columns first
    15, 16,
    10, 11,
    20, 21,
    // followed by 10 zeros to make up the difference with primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias first (cr == 2 of them)
    // 64516 + 2 - (7 + 12 + 17 + 22) * 127 = 57,152 = 0xDF40
    0x40, 0xDF, 0, 0,
    // 64516 + 3 - (8 + 13 + 18 + 23) * 127 = 56,645 = 0xDD45
    0x45, 0xDD, 0, 0,
    // then weights, channels first
    7, 8, 17, 18, 12, 13, 22, 23,
    // zero padding up to primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias
    // 64516 + 4 - (9 + 14 + 19 + 24) * 127 = 56,138 = 0xDB4A
    0x4A, 0xDB, 0, 0,
    // no sixth channel: zero bias slot
    0, 0, 0, 0,
    // weights (second lane is zero padding)
    9, 0, 19, 0, 14, 0, 24, 0,
    // zero padding up to primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
472*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, GHW kernel layout.
// Case: 3x1 kernel, 2 channels, primary_tile == h * w and c == cr, so the
// expected output is a single channel block with no zero padding.
TEST(PACK_QS8_DWCONV_GHW_W, primary_tile_eq_kernel_size) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<int8_t> k(c * h * w);
  std::iota(k.begin(), k.end(), b.size());  // k = [2, 3, 4, 5, 6, 7]

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_ghw_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // bias first
    // 0 - (2 + 3 + 4) * 127 = -1143 = 0xFFFFFB89
    0x89, 0xFB, 0xFF, 0xFF,
    // 1 - (5 + 6 + 7) * 127 = -2285 = 0xFFFFF713
    0x13, 0xF7, 0xFF, 0xFF,
    // then weights, channels first
    2, 5,
    3, 6,
    4, 7,
  };
  ASSERT_EQ(expected, packed_weights);
}
515*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, GHW kernel layout.
// Case: 3x1 kernel, 5 channels packed in blocks of cr == 2. The expected output
// has three channel blocks; the last one holds a single real channel and is
// filled out with a zero bias and zero weights.
TEST(PACK_QS8_DWCONV_GHW_W, primary_tile_eq_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<int8_t> k(c * h * w);
  // k = [
  //   5, 6, 7,
  //   8, 9, 10,
  //   11, 12, 13,
  //   14, 15, 16,
  //   17, 18, 19]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes,
  // with the channel count rounded up to a multiple of cr.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_ghw_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // cr blocks
    // bias first (cr == 2 of them)
    // 0 - (5 + 6 + 7) * 127 = -2286 = 0xFFFFF712
    0x12, 0xF7, 0xFF, 0xFF,
    // 1 - (8 + 9 + 10) * 127 = -3428 = 0xFFFFF29C
    0x9C, 0xF2, 0xFF, 0xFF,
    // then weights, channels first
    5, 8, 6, 9, 7, 10,
    // bias again
    // 2 - (11 + 12 + 13) * 127 = -4570 = 0xFFFFEE26
    0x26, 0xEE, 0xFF, 0xFF,
    // 3 - (14 + 15 + 16) * 127 = -5712 = 0xFFFFE9B0
    0xB0, 0xE9, 0xFF, 0xFF,
    // then weights, channels first
    11, 14, 12, 15, 13, 16,
    // bias again
    // 4 - (17 + 18 + 19) * 127 = -6854 = 0xFFFFE53A
    0x3A, 0xE5, 0xFF, 0xFF,
    // zero bias for the padding channel
    0, 0, 0, 0,
    // then weights, channels first (padding channel is zero)
    17, 0, 18, 0, 19, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
574*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, GHW kernel layout.
// Case: 2x2 kernel, 2 channels, primary_tile (9) larger than h * w (4). The
// expected output has the 4 real taps per channel followed by zero bytes for
// the 5 unused taps.
TEST(PACK_QS8_DWCONV_GHW_W, primary_tile_gt_kernel_size) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<int8_t> k(c * h * w);
  // k = [
  //   2, 3,
  //   4, 5,
  //   6, 7,
  //   8, 9]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_ghw_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 0 - (2 + 3 + 4 + 5) * 127 = -1778 = 0xFFFFF90E
    0x0E, 0xF9, 0xFF, 0xFF,
    // 1 - (6 + 7 + 8 + 9) * 127 = -3809 = 0xFFFFF11F
    0x1F, 0xF1, 0xFF, 0xFF,
    // then weights, channels first
    2, 6,
    // go down the columns first
    4, 8, 3, 7, 5, 9,
    // followed by 10 zeros to make up the difference with primary_tile
    // ((9 - 4) unused taps * 2 channels)
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
622*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, GHW kernel layout.
// Case: 2x2 kernel, 5 channels in blocks of cr == 2, primary_tile (9) larger
// than h * w (4). The expected output has three channel blocks, each padded
// with zeros up to primary_tile taps; the last block also carries a zero bias
// and zero weights for the missing sixth channel.
TEST(PACK_QS8_DWCONV_GHW_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<int8_t> k(c * h * w);
  // k = [
  //   5, 6,
  //   7, 8,
  //   9, 10,
  //   11, 12,
  //   13, 14,
  //   15, 16,
  //   17, 18,
  //   19, 20,
  //   21, 22,
  //   23, 24]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes,
  // with the channel count rounded up to a multiple of cr.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_ghw_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 0 - (5 + 6 + 7 + 8) * 127 = -3302 = 0xFFFFF31A
    0x1A, 0xF3, 0xFF, 0xFF,
    // 1 - (9 + 10 + 11 + 12) * 127 = -5333 = 0xFFFFEB2B
    0x2B, 0xEB, 0xFF, 0xFF,
    // then weights, channels first
    5, 9,
    // go down the columns first
    7, 11,
    6, 10,
    8, 12,
    // followed by 10 zeros to make up the difference with primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias first (cr == 2 of them)
    // 2 - (13 + 14 + 15 + 16) * 127 = -7364 = 0xFFFFE33C
    0x3C, 0xE3, 0xFF, 0xFF,
    // 3 - (17 + 18 + 19 + 20) * 127 = -9395 = 0xFFFFDB4D
    0x4D, 0xDB, 0xFF, 0xFF,
    // then weights, channels first
    13, 17, 15, 19, 14, 18, 16, 20,
    // 10 zeros of primary_tile padding
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias
    // 4 - (21 + 22 + 23 + 24) * 127 = -11426 = 0xFFFFD35E
    0x5E, 0xD3, 0xFF, 0xFF,
    // zero bias for the padding channel
    0, 0, 0, 0,
    // weights (padding channel is zero)
    21, 0, 23, 0, 22, 0, 24, 0,
    // 10 zeros of primary_tile padding
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
693*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, HWG kernel layout.
// Case: 3x1 kernel, 2 channels, primary_tile == h * w and c == cr, so the
// expected output is a single channel block with no zero padding.
TEST(PACK_QS8_DWCONV_HWG_W, primary_tile_eq_kernel_size) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<int8_t> k(c * h * w);
  std::iota(k.begin(), k.end(), b.size());  // k = [2, 3, 4, 5, 6, 7]

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_hwg_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // bias first
    // 0 - (2 + 4 + 6) * 127 = -1524 = 0xFFFFFA0C
    0x0C, 0xFA, 0xFF, 0xFF,
    // 1 - (3 + 5 + 7) * 127 = -1904 = 0xFFFFF890
    0x90, 0xF8, 0xFF, 0xFF,
    // then weights, channels first
    2, 3,
    4, 5,
    6, 7,
  };
  ASSERT_EQ(expected, packed_weights);
}
736*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, HWG kernel layout.
// Case: 3x1 kernel, 5 channels packed in blocks of cr == 2. The expected output
// has three channel blocks; the last one holds a single real channel and is
// filled out with a zero bias and zero weights.
TEST(PACK_QS8_DWCONV_HWG_W, primary_tile_eq_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<int8_t> k(c * h * w);
  // k = [
  //   5, 6, 7, 8, 9,
  //   10, 11, 12, 13, 14,
  //   15, 16, 17, 18, 19]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes,
  // with the channel count rounded up to a multiple of cr.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_hwg_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // cr blocks
    // bias first (cr == 2 of them)
    // 0 - (5 + 10 + 15) * 127 = -3810 = 0xFFFFF11E
    0x1E, 0xF1, 0xFF, 0xFF,
    // 1 - (6 + 11 + 16) * 127 = -4190 = 0xFFFFEFA2
    0xA2, 0xEF, 0xFF, 0xFF,
    // then weights, channels first
    5, 6, 10, 11, 15, 16,
    // bias again
    // 2 - (7 + 12 + 17) * 127 = -4570 = 0xFFFFEE26
    0x26, 0xEE, 0xFF, 0xFF,
    // 3 - (8 + 13 + 18) * 127 = -4950 = 0xFFFFECAA
    0xAA, 0xEC, 0xFF, 0xFF,
    // then weights, channels first
    7, 8, 12, 13, 17, 18,
    // bias again
    // 4 - (9 + 14 + 19) * 127 = -5330 = 0xFFFFEB2E
    0x2E, 0xEB, 0xFF, 0xFF,
    // zero bias for the padding channel
    0, 0, 0, 0,
    // then weights, channels first (padding channel is zero)
    9, 0, 14, 0, 19, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
793*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, HWG kernel layout.
// Case: 2x2 kernel, 2 channels, primary_tile (9) larger than h * w (4). The
// expected output has the 4 real taps per channel followed by zero bytes for
// the 5 unused taps.
TEST(PACK_QS8_DWCONV_HWG_W, primary_tile_gt_kernel_size) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<int8_t> k(c * h * w);
  // k = [
  //   2, 3,
  //   4, 5,
  //   6, 7,
  //   8, 9]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_hwg_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 0 - (2 + 4 + 6 + 8) * 127 = -2540 = 0xFFFFF614
    0x14, 0xF6, 0xFF, 0xFF,
    // 1 - (3 + 5 + 7 + 9) * 127 = -3047 = 0xFFFFF419
    0x19, 0xF4, 0xFF, 0xFF,
    // then weights, channels first
    2, 3,
    // go down the columns first
    6, 7, 4, 5, 8, 9,
    // followed by 10 zeros to make up the difference with primary_tile
    // ((9 - 4) unused taps * 2 channels)
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
841*4bdc9457SAndroid Build Coastguard Worker
// QS8 depthwise-conv weight packing, HWG kernel layout.
// Case: 2x2 kernel, 5 channels in blocks of cr == 2, primary_tile (9) larger
// than h * w (4). The expected output has three channel blocks, each padded
// with zeros up to primary_tile taps; the last block also carries a zero bias
// and zero weights for the missing sixth channel.
TEST(PACK_QS8_DWCONV_HWG_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  const size_t primary_tile = 9;
  const size_t h = 2;
  const size_t w = 2;
  const size_t c = 5;
  const size_t cr = 2;

  std::vector<int32_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  std::vector<int8_t> k(c * h * w);
  // k = [
  //   5, 6, 7, 8, 9,
  //   10, 11, 12, 13, 14,
  //   15, 16, 17, 18, 19,
  //   20, 21, 22, 23, 24]
  std::iota(k.begin(), k.end(), b.size());

  // Per channel: one int32 bias (4 bytes) followed by primary_tile weight bytes,
  // with the channel count rounded up to a multiple of cr.
  std::vector<uint8_t> packed_weights(((primary_tile + sizeof(int32_t)/sizeof(uint8_t)) * round_up_po2(c, cr)));

  xnn_qs8_packing_params params = {
    .input_zero_point = 127,
  };
  xnn_pack_qs8_dwconv_hwg_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    &params);

  // Each 32-bit bias is spelled out as four bytes, least-significant byte first.
  const std::vector<uint8_t> expected = {
    // bias first (cr == 2 of them)
    // 0 - (5 + 10 + 15 + 20) * 127 = -6350 = 0xFFFFE732
    0x32, 0xE7, 0xFF, 0xFF,
    // 1 - (6 + 11 + 16 + 21) * 127 = -6857 = 0xFFFFE537
    0x37, 0xE5, 0xFF, 0xFF,
    // then weights, channels first
    5, 6,
    // go down the columns first
    15, 16,
    10, 11,
    20, 21,
    // followed by 10 zeros to make up the difference with primary_tile
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias first (cr == 2 of them)
    // 2 - (7 + 12 + 17 + 22) * 127 = -7364 = 0xFFFFE33C
    0x3C, 0xE3, 0xFF, 0xFF,
    // 3 - (8 + 13 + 18 + 23) * 127 = -7871 = 0xFFFFE141
    0x41, 0xE1, 0xFF, 0xFF,
    // then weights, channels first
    7, 8, 17, 18, 12, 13, 22, 23,
    // 10 zeros of primary_tile padding
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // bias
    // 4 - (9 + 14 + 19 + 24) * 127 = -8378 = 0xFFFFDF46
    0x46, 0xDF, 0xFF, 0xFF,
    // zero bias for the padding channel
    0, 0, 0, 0,
    // weights (padding channel is zero)
    9, 0, 19, 0, 14, 0, 24, 0,
    // 10 zeros of primary_tile padding
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  ASSERT_EQ(expected, packed_weights);
}
906*4bdc9457SAndroid Build Coastguard Worker
// F16 depthwise-conv weight packing, GHW kernel layout.
// Case: 3x1 kernel, 2 channels, primary_tile == h * w and c == cr. Bias and
// weights are raw 16-bit patterns here, and the expected output contains them
// unchanged: the biases first, then the weights interleaved by channel.
TEST(PACK_F16_DWCONV_GHW_W, primary_tile_eq_kernel_size) {
  const size_t primary_tile = 3;
  const size_t h = 3;
  const size_t w = 1;
  const size_t c = 2;
  const size_t cr = 2;

  std::vector<uint16_t> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  std::vector<uint16_t> k(c * h * w);
  std::iota(k.begin(), k.end(), b.size());  // k = [2, 3, 4, 5, 6, 7]

  // Per channel: one bias element followed by primary_tile weight elements.
  std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));

  xnn_pack_f16_dwconv_ghw_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    nullptr);

  const std::vector<uint16_t> expected = {
    // bias first
    0, 1,
    // then weights, channels first
    2, 5,
    3, 6,
    4, 7,
  };
  ASSERT_EQ(expected, packed_weights);
}
942*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f16_dwconv_ghw_w for a 3x1 kernel when the channel count
// (5) exceeds cr (2). The expected buffer is three channel blocks of cr
// lanes each, every block being its biases followed by its weights; the last
// block only carries channel 4, so its second lane is expected to be zero.
// Inputs are plain integers stored in uint16_t.
TEST(PACK_F16_DWCONV_GHW_W,primary_tile_eq_kernel_size_channels_gt_cr)943*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F16_DWCONV_GHW_W, primary_tile_eq_kernel_size_channels_gt_cr) {
944*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 3;
945*4bdc9457SAndroid Build Coastguard Worker size_t h = 3;
946*4bdc9457SAndroid Build Coastguard Worker size_t w = 1;
947*4bdc9457SAndroid Build Coastguard Worker size_t c = 5;
948*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
949*4bdc9457SAndroid Build Coastguard Worker
950*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(c);
951*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1, 2, 3, 4]
  // Each row of the listing below is one channel's three taps.
952*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> k(c * h * w); // k = [
953*4bdc9457SAndroid Build Coastguard Worker // 5, 6, 7,
954*4bdc9457SAndroid Build Coastguard Worker // 8, 9, 10,
955*4bdc9457SAndroid Build Coastguard Worker // 11, 12, 13,
956*4bdc9457SAndroid Build Coastguard Worker // 14, 15, 16,
957*4bdc9457SAndroid Build Coastguard Worker // 17, 18, 19]
958*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot; the
  // expected list below has 24 entries, i.e. 6 channel slots for c == 5.
959*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
960*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
961*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_dwconv_ghw_w(
962*4bdc9457SAndroid Build Coastguard Worker primary_tile,
963*4bdc9457SAndroid Build Coastguard Worker h,
964*4bdc9457SAndroid Build Coastguard Worker w,
965*4bdc9457SAndroid Build Coastguard Worker c,
966*4bdc9457SAndroid Build Coastguard Worker cr,
967*4bdc9457SAndroid Build Coastguard Worker k.data(),
968*4bdc9457SAndroid Build Coastguard Worker b.data(),
969*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
970*4bdc9457SAndroid Build Coastguard Worker 0,
971*4bdc9457SAndroid Build Coastguard Worker nullptr);
972*4bdc9457SAndroid Build Coastguard Worker
973*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> expected = {
974*4bdc9457SAndroid Build Coastguard Worker // cr blocks
975*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
976*4bdc9457SAndroid Build Coastguard Worker 0, 1,
977*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
978*4bdc9457SAndroid Build Coastguard Worker 5, 8, 6, 9, 7, 10,
979*4bdc9457SAndroid Build Coastguard Worker // bias again
980*4bdc9457SAndroid Build Coastguard Worker 2, 3,
981*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
982*4bdc9457SAndroid Build Coastguard Worker 11, 14, 12, 15, 13, 16,
983*4bdc9457SAndroid Build Coastguard Worker // bias again
  // (only channel 4 remains; the second lane of this block stays zero)
984*4bdc9457SAndroid Build Coastguard Worker 4, 0,
985*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
986*4bdc9457SAndroid Build Coastguard Worker 17, 0, 18, 0, 19, 0,
987*4bdc9457SAndroid Build Coastguard Worker };
988*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
989*4bdc9457SAndroid Build Coastguard Worker }
990*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f16_dwconv_ghw_w for a 2x2 kernel over 2 channels when
// primary_tile (9) is larger than the kernel size (h * w == 4). The expected
// buffer holds the biases, then the four taps ordered column by column
// (h varies fastest), then zeros for the 5 unused taps of both channels.
// Inputs are plain integers stored in uint16_t.
TEST(PACK_F16_DWCONV_GHW_W,primary_tile_gt_kernel_size)991*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F16_DWCONV_GHW_W, primary_tile_gt_kernel_size) {
992*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 9;
993*4bdc9457SAndroid Build Coastguard Worker size_t h = 2;
994*4bdc9457SAndroid Build Coastguard Worker size_t w = 2;
995*4bdc9457SAndroid Build Coastguard Worker size_t c = 2;
996*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
997*4bdc9457SAndroid Build Coastguard Worker
998*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(c);
999*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1]
  // Each row of the listing below is one kernel row; the first two rows
  // belong to channel 0 and the last two to channel 1.
1000*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> k(c * h * w); // k = [
1001*4bdc9457SAndroid Build Coastguard Worker // 2, 3,
1002*4bdc9457SAndroid Build Coastguard Worker // 4, 5,
1003*4bdc9457SAndroid Build Coastguard Worker // 6, 7,
1004*4bdc9457SAndroid Build Coastguard Worker // 8, 9]
1005*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot, so the
  // buffer is sized by primary_tile rather than by the kernel size.
1006*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1007*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1008*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_dwconv_ghw_w(
1009*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1010*4bdc9457SAndroid Build Coastguard Worker h,
1011*4bdc9457SAndroid Build Coastguard Worker w,
1012*4bdc9457SAndroid Build Coastguard Worker c,
1013*4bdc9457SAndroid Build Coastguard Worker cr,
1014*4bdc9457SAndroid Build Coastguard Worker k.data(),
1015*4bdc9457SAndroid Build Coastguard Worker b.data(),
1016*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1017*4bdc9457SAndroid Build Coastguard Worker 0,
1018*4bdc9457SAndroid Build Coastguard Worker nullptr);
1019*4bdc9457SAndroid Build Coastguard Worker
1020*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> expected = {
1021*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1022*4bdc9457SAndroid Build Coastguard Worker 0, 1,
1023*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1024*4bdc9457SAndroid Build Coastguard Worker 2, 6,
1025*4bdc9457SAndroid Build Coastguard Worker // go down the columns first
1026*4bdc9457SAndroid Build Coastguard Worker 4, 8, 3, 7, 5, 9,
1027*4bdc9457SAndroid Build Coastguard Worker // followed by 10 zeros to make up the difference with primary_tile
1028*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1029*4bdc9457SAndroid Build Coastguard Worker };
1030*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1031*4bdc9457SAndroid Build Coastguard Worker }
1032*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f16_dwconv_ghw_w for a 2x2 kernel when primary_tile (9) is
// larger than the kernel size (4) and the channel count (5) exceeds cr (2).
// The expected buffer is three channel blocks; each block is its biases, its
// four taps ordered column by column, then 10 zeros for the unused taps. The
// last block only carries channel 4, so its second lane is zero throughout.
// Inputs are plain integers stored in uint16_t.
TEST(PACK_F16_DWCONV_GHW_W,primary_tile_gt_kernel_size_channels_gt_cr)1033*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F16_DWCONV_GHW_W, primary_tile_gt_kernel_size_channels_gt_cr) {
1034*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 9;
1035*4bdc9457SAndroid Build Coastguard Worker size_t h = 2;
1036*4bdc9457SAndroid Build Coastguard Worker size_t w = 2;
1037*4bdc9457SAndroid Build Coastguard Worker size_t c = 5;
1038*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1039*4bdc9457SAndroid Build Coastguard Worker
1040*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(c);
1041*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1, 2, 3, 4]
  // Each row of the listing below is one kernel row; every consecutive pair
  // of rows is one channel's 2x2 kernel.
1042*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> k(c * h * w); // k = [
1043*4bdc9457SAndroid Build Coastguard Worker // 5, 6,
1044*4bdc9457SAndroid Build Coastguard Worker // 7, 8,
1045*4bdc9457SAndroid Build Coastguard Worker // 9, 10,
1046*4bdc9457SAndroid Build Coastguard Worker // 11, 12,
1047*4bdc9457SAndroid Build Coastguard Worker // 13, 14,
1048*4bdc9457SAndroid Build Coastguard Worker // 15, 16,
1049*4bdc9457SAndroid Build Coastguard Worker // 17, 18,
1050*4bdc9457SAndroid Build Coastguard Worker // 19, 20,
1051*4bdc9457SAndroid Build Coastguard Worker // 21, 22,
1052*4bdc9457SAndroid Build Coastguard Worker // 23, 24]
1053*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot; the
  // expected list below has 60 entries, i.e. 6 channel slots for c == 5.
1054*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1055*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1056*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_dwconv_ghw_w(
1057*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1058*4bdc9457SAndroid Build Coastguard Worker h,
1059*4bdc9457SAndroid Build Coastguard Worker w,
1060*4bdc9457SAndroid Build Coastguard Worker c,
1061*4bdc9457SAndroid Build Coastguard Worker cr,
1062*4bdc9457SAndroid Build Coastguard Worker k.data(),
1063*4bdc9457SAndroid Build Coastguard Worker b.data(),
1064*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1065*4bdc9457SAndroid Build Coastguard Worker 0,
1066*4bdc9457SAndroid Build Coastguard Worker nullptr);
1067*4bdc9457SAndroid Build Coastguard Worker
1068*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> expected = {
1069*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1070*4bdc9457SAndroid Build Coastguard Worker 0, 1,
1071*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1072*4bdc9457SAndroid Build Coastguard Worker 5, 9,
1073*4bdc9457SAndroid Build Coastguard Worker // go down the columns first
1074*4bdc9457SAndroid Build Coastguard Worker 7, 11,
1075*4bdc9457SAndroid Build Coastguard Worker 6, 10,
1076*4bdc9457SAndroid Build Coastguard Worker 8, 12,
1077*4bdc9457SAndroid Build Coastguard Worker // followed by 10 zeros to make up the difference with primary_tile
1078*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1079*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1080*4bdc9457SAndroid Build Coastguard Worker 2, 3,
1081*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1082*4bdc9457SAndroid Build Coastguard Worker 13, 17, 15, 19, 14, 18, 16, 20,
  // 10 zeros for the 5 unused taps of this channel block
1083*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1084*4bdc9457SAndroid Build Coastguard Worker // bias
1085*4bdc9457SAndroid Build Coastguard Worker 4, 0,
1086*4bdc9457SAndroid Build Coastguard Worker // weights
1087*4bdc9457SAndroid Build Coastguard Worker 21, 0, 23, 0, 22, 0, 24, 0,
  // 10 zeros for the 5 unused taps of this channel block
1088*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1089*4bdc9457SAndroid Build Coastguard Worker };
1090*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1091*4bdc9457SAndroid Build Coastguard Worker }
1092*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f16_dwconv_hwg_w for a 3x1 kernel over 2 channels when
// primary_tile equals the kernel size (h * w == 3). The expected buffer shows
// k being read tap by tap with channels innermost: [2, 3] is tap 0 for
// channels 0 and 1, [4, 5] is tap 1, [6, 7] is tap 2. Inputs are plain
// integers stored in uint16_t; this test does not encode them as
// half-precision values.
TEST(PACK_F16_DWCONV_HWG_W,primary_tile_eq_kernel_size)1093*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F16_DWCONV_HWG_W, primary_tile_eq_kernel_size) {
1094*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 3;
1095*4bdc9457SAndroid Build Coastguard Worker size_t h = 3;
1096*4bdc9457SAndroid Build Coastguard Worker size_t w = 1;
1097*4bdc9457SAndroid Build Coastguard Worker size_t c = 2;
1098*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1099*4bdc9457SAndroid Build Coastguard Worker
1100*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(c);
1101*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1]
1102*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> k(c * h * w); // k = [2, 3, 4, 5, 6, 7]
1103*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot; the
  // number of channel slots comes from round_up_po2(c, cr).
1104*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1105*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1106*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_dwconv_hwg_w(
1107*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1108*4bdc9457SAndroid Build Coastguard Worker h,
1109*4bdc9457SAndroid Build Coastguard Worker w,
1110*4bdc9457SAndroid Build Coastguard Worker c,
1111*4bdc9457SAndroid Build Coastguard Worker cr,
1112*4bdc9457SAndroid Build Coastguard Worker k.data(),
1113*4bdc9457SAndroid Build Coastguard Worker b.data(),
1114*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1115*4bdc9457SAndroid Build Coastguard Worker 0,
1116*4bdc9457SAndroid Build Coastguard Worker nullptr);
1117*4bdc9457SAndroid Build Coastguard Worker
1118*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> expected = {
1119*4bdc9457SAndroid Build Coastguard Worker // bias first
1120*4bdc9457SAndroid Build Coastguard Worker 0, 1,
1121*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
  // (each row below is one kernel tap: {channel 0, channel 1})
1122*4bdc9457SAndroid Build Coastguard Worker 2, 3,
1123*4bdc9457SAndroid Build Coastguard Worker 4, 5,
1124*4bdc9457SAndroid Build Coastguard Worker 6, 7,
1125*4bdc9457SAndroid Build Coastguard Worker };
1126*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1127*4bdc9457SAndroid Build Coastguard Worker }
1128*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f16_dwconv_hwg_w for a 3x1 kernel when the channel count
// (5) exceeds cr (2). k is laid out with channels innermost (each row of the
// listing is one tap across all 5 channels). The expected buffer is three
// channel blocks of cr lanes each, every block being its biases followed by
// its weights; the last block only carries channel 4, so its second lane is
// expected to be zero. Inputs are plain integers stored in uint16_t.
TEST(PACK_F16_DWCONV_HWG_W,primary_tile_eq_kernel_size_channels_gt_cr)1129*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F16_DWCONV_HWG_W, primary_tile_eq_kernel_size_channels_gt_cr) {
1130*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 3;
1131*4bdc9457SAndroid Build Coastguard Worker size_t h = 3;
1132*4bdc9457SAndroid Build Coastguard Worker size_t w = 1;
1133*4bdc9457SAndroid Build Coastguard Worker size_t c = 5;
1134*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1135*4bdc9457SAndroid Build Coastguard Worker
1136*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(c);
1137*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1, 2, 3, 4]
1138*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> k(c * h * w); // k = [
1139*4bdc9457SAndroid Build Coastguard Worker // 5, 6, 7, 8, 9,
1140*4bdc9457SAndroid Build Coastguard Worker // 10, 11, 12, 13, 14,
1141*4bdc9457SAndroid Build Coastguard Worker // 15, 16, 17, 18, 19]
1142*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot; the
  // expected list below has 24 entries, i.e. 6 channel slots for c == 5.
1143*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1144*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1145*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_dwconv_hwg_w(
1146*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1147*4bdc9457SAndroid Build Coastguard Worker h,
1148*4bdc9457SAndroid Build Coastguard Worker w,
1149*4bdc9457SAndroid Build Coastguard Worker c,
1150*4bdc9457SAndroid Build Coastguard Worker cr,
1151*4bdc9457SAndroid Build Coastguard Worker k.data(),
1152*4bdc9457SAndroid Build Coastguard Worker b.data(),
1153*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1154*4bdc9457SAndroid Build Coastguard Worker 0,
1155*4bdc9457SAndroid Build Coastguard Worker nullptr);
1156*4bdc9457SAndroid Build Coastguard Worker
1157*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> expected = {
1158*4bdc9457SAndroid Build Coastguard Worker // cr blocks
1159*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1160*4bdc9457SAndroid Build Coastguard Worker 0, 1,
1161*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1162*4bdc9457SAndroid Build Coastguard Worker 5, 6, 10, 11, 15, 16,
1163*4bdc9457SAndroid Build Coastguard Worker // bias again
1164*4bdc9457SAndroid Build Coastguard Worker 2, 3,
1165*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1166*4bdc9457SAndroid Build Coastguard Worker 7, 8, 12, 13, 17, 18,
1167*4bdc9457SAndroid Build Coastguard Worker // bias again
  // (only channel 4 remains; the second lane of this block stays zero)
1168*4bdc9457SAndroid Build Coastguard Worker 4, 0,
1169*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1170*4bdc9457SAndroid Build Coastguard Worker 9, 0, 14, 0, 19, 0,
1171*4bdc9457SAndroid Build Coastguard Worker };
1172*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1173*4bdc9457SAndroid Build Coastguard Worker }
1174*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f16_dwconv_hwg_w for a 2x2 kernel over 2 channels when
// primary_tile (9) is larger than the kernel size (h * w == 4). k is laid out
// with channels innermost (each row of the listing is one spatial position
// for both channels). The expected buffer holds the biases, then the four
// taps ordered column by column (h varies fastest), then zeros for the 5
// unused taps of both channels. Inputs are plain integers stored in uint16_t.
TEST(PACK_F16_DWCONV_HWG_W,primary_tile_gt_kernel_size)1175*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F16_DWCONV_HWG_W, primary_tile_gt_kernel_size) {
1176*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 9;
1177*4bdc9457SAndroid Build Coastguard Worker size_t h = 2;
1178*4bdc9457SAndroid Build Coastguard Worker size_t w = 2;
1179*4bdc9457SAndroid Build Coastguard Worker size_t c = 2;
1180*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1181*4bdc9457SAndroid Build Coastguard Worker
1182*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(c);
1183*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1]
1184*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> k(c * h * w); // k = [
1185*4bdc9457SAndroid Build Coastguard Worker // 2, 3,
1186*4bdc9457SAndroid Build Coastguard Worker // 4, 5,
1187*4bdc9457SAndroid Build Coastguard Worker // 6, 7,
1188*4bdc9457SAndroid Build Coastguard Worker // 8, 9]
1189*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot, so the
  // buffer is sized by primary_tile rather than by the kernel size.
1190*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1191*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1192*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_dwconv_hwg_w(
1193*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1194*4bdc9457SAndroid Build Coastguard Worker h,
1195*4bdc9457SAndroid Build Coastguard Worker w,
1196*4bdc9457SAndroid Build Coastguard Worker c,
1197*4bdc9457SAndroid Build Coastguard Worker cr,
1198*4bdc9457SAndroid Build Coastguard Worker k.data(),
1199*4bdc9457SAndroid Build Coastguard Worker b.data(),
1200*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1201*4bdc9457SAndroid Build Coastguard Worker 0,
1202*4bdc9457SAndroid Build Coastguard Worker nullptr);
1203*4bdc9457SAndroid Build Coastguard Worker
1204*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> expected = {
1205*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1206*4bdc9457SAndroid Build Coastguard Worker 0, 1,
1207*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1208*4bdc9457SAndroid Build Coastguard Worker 2, 3,
1209*4bdc9457SAndroid Build Coastguard Worker // go down the columns first
1210*4bdc9457SAndroid Build Coastguard Worker 6, 7, 4, 5, 8, 9,
1211*4bdc9457SAndroid Build Coastguard Worker // followed by 10 zeros to make up the difference with primary_tile
1212*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1213*4bdc9457SAndroid Build Coastguard Worker };
1214*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1215*4bdc9457SAndroid Build Coastguard Worker }
1216*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f16_dwconv_hwg_w for a 2x2 kernel when primary_tile (9) is
// larger than the kernel size (4) and the channel count (5) exceeds cr (2).
// k is laid out with channels innermost (each row of the listing is one
// spatial position across all 5 channels). The expected buffer is three
// channel blocks; each block is its biases, its four taps ordered column by
// column, then 10 zeros for the unused taps. The last block only carries
// channel 4, so its second lane is zero throughout. Inputs are plain integers
// stored in uint16_t.
TEST(PACK_F16_DWCONV_HWG_W,primary_tile_gt_kernel_size_channels_gt_cr)1217*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F16_DWCONV_HWG_W, primary_tile_gt_kernel_size_channels_gt_cr) {
1218*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 9;
1219*4bdc9457SAndroid Build Coastguard Worker size_t h = 2;
1220*4bdc9457SAndroid Build Coastguard Worker size_t w = 2;
1221*4bdc9457SAndroid Build Coastguard Worker size_t c = 5;
1222*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1223*4bdc9457SAndroid Build Coastguard Worker
1224*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(c);
1225*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1, 2, 3, 4]
1226*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> k(c * h * w); // k = [
1227*4bdc9457SAndroid Build Coastguard Worker // 5, 6, 7, 8, 9,
1228*4bdc9457SAndroid Build Coastguard Worker // 10, 11, 12, 13, 14,
1229*4bdc9457SAndroid Build Coastguard Worker // 15, 16, 17, 18, 19,
1230*4bdc9457SAndroid Build Coastguard Worker // 20, 21, 22, 23, 24]
1231*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot; the
  // expected list below has 60 entries, i.e. 6 channel slots for c == 5.
1232*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1233*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1234*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_dwconv_hwg_w(
1235*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1236*4bdc9457SAndroid Build Coastguard Worker h,
1237*4bdc9457SAndroid Build Coastguard Worker w,
1238*4bdc9457SAndroid Build Coastguard Worker c,
1239*4bdc9457SAndroid Build Coastguard Worker cr,
1240*4bdc9457SAndroid Build Coastguard Worker k.data(),
1241*4bdc9457SAndroid Build Coastguard Worker b.data(),
1242*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1243*4bdc9457SAndroid Build Coastguard Worker 0,
1244*4bdc9457SAndroid Build Coastguard Worker nullptr);
1245*4bdc9457SAndroid Build Coastguard Worker
1246*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> expected = {
1247*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1248*4bdc9457SAndroid Build Coastguard Worker 0, 1,
1249*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1250*4bdc9457SAndroid Build Coastguard Worker 5, 6,
1251*4bdc9457SAndroid Build Coastguard Worker // go down the columns first
1252*4bdc9457SAndroid Build Coastguard Worker 15, 16,
1253*4bdc9457SAndroid Build Coastguard Worker 10, 11,
1254*4bdc9457SAndroid Build Coastguard Worker 20, 21,
1255*4bdc9457SAndroid Build Coastguard Worker // followed by 10 zeros to make up the difference with primary_tile
1256*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1257*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1258*4bdc9457SAndroid Build Coastguard Worker 2, 3,
1259*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1260*4bdc9457SAndroid Build Coastguard Worker 7, 8, 17, 18, 12, 13, 22, 23,
  // 10 zeros for the 5 unused taps of this channel block
1261*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1262*4bdc9457SAndroid Build Coastguard Worker // bias
1263*4bdc9457SAndroid Build Coastguard Worker 4, 0,
1264*4bdc9457SAndroid Build Coastguard Worker // weights
1265*4bdc9457SAndroid Build Coastguard Worker 9, 0, 19, 0, 14, 0, 24, 0,
  // 10 zeros for the 5 unused taps of this channel block
1266*4bdc9457SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1267*4bdc9457SAndroid Build Coastguard Worker };
1268*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1269*4bdc9457SAndroid Build Coastguard Worker }
1270*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_ghw_w for a 3x1 kernel over 2 channels when
// primary_tile equals the kernel size (h * w == 3). The expected buffer shows
// k being read channel by channel: [2, 3, 4] are channel 0's taps and
// [5, 6, 7] are channel 1's. All values are small whole numbers stored as
// float, which the exact ASSERT_EQ comparison relies on.
TEST(PACK_F32_DWCONV_GHW_W,primary_tile_eq_kernel_size)1271*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F32_DWCONV_GHW_W, primary_tile_eq_kernel_size) {
1272*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 3;
1273*4bdc9457SAndroid Build Coastguard Worker size_t h = 3;
1274*4bdc9457SAndroid Build Coastguard Worker size_t w = 1;
1275*4bdc9457SAndroid Build Coastguard Worker size_t c = 2;
1276*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1277*4bdc9457SAndroid Build Coastguard Worker
1278*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(c);
1279*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1]
1280*4bdc9457SAndroid Build Coastguard Worker std::vector<float> k(c * h * w); // k = [2, 3, 4, 5, 6, 7]
1281*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot; the
  // number of channel slots comes from round_up_po2(c, cr).
1282*4bdc9457SAndroid Build Coastguard Worker std::vector<float> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1283*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1284*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_dwconv_ghw_w(
1285*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1286*4bdc9457SAndroid Build Coastguard Worker h,
1287*4bdc9457SAndroid Build Coastguard Worker w,
1288*4bdc9457SAndroid Build Coastguard Worker c,
1289*4bdc9457SAndroid Build Coastguard Worker cr,
1290*4bdc9457SAndroid Build Coastguard Worker k.data(),
1291*4bdc9457SAndroid Build Coastguard Worker b.data(),
1292*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1293*4bdc9457SAndroid Build Coastguard Worker 0,
1294*4bdc9457SAndroid Build Coastguard Worker nullptr);
1295*4bdc9457SAndroid Build Coastguard Worker
1296*4bdc9457SAndroid Build Coastguard Worker std::vector<float> expected = {
1297*4bdc9457SAndroid Build Coastguard Worker // bias first
1298*4bdc9457SAndroid Build Coastguard Worker 0.0f, 1.0f,
1299*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
  // (each row below is one kernel tap: {channel 0, channel 1})
1300*4bdc9457SAndroid Build Coastguard Worker 2.0f, 5.0f,
1301*4bdc9457SAndroid Build Coastguard Worker 3.0f, 6.0f,
1302*4bdc9457SAndroid Build Coastguard Worker 4.0f, 7.0f,
1303*4bdc9457SAndroid Build Coastguard Worker };
1304*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1305*4bdc9457SAndroid Build Coastguard Worker }
1306*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_ghw_w for a 3x1 kernel when the channel count
// (5) exceeds cr (2). The expected buffer is three channel blocks of cr
// lanes each, every block being its biases followed by its weights; the last
// block only carries channel 4, so its second lane is expected to be zero.
// All values are small whole numbers stored as float, which the exact
// ASSERT_EQ comparison relies on.
TEST(PACK_F32_DWCONV_GHW_W,primary_tile_eq_kernel_size_channels_gt_cr)1307*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F32_DWCONV_GHW_W, primary_tile_eq_kernel_size_channels_gt_cr) {
1308*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 3;
1309*4bdc9457SAndroid Build Coastguard Worker size_t h = 3;
1310*4bdc9457SAndroid Build Coastguard Worker size_t w = 1;
1311*4bdc9457SAndroid Build Coastguard Worker size_t c = 5;
1312*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1313*4bdc9457SAndroid Build Coastguard Worker
1314*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(c);
1315*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1, 2, 3, 4]
  // Each row of the listing below is one channel's three taps.
1316*4bdc9457SAndroid Build Coastguard Worker std::vector<float> k(c * h * w); // k = [
1317*4bdc9457SAndroid Build Coastguard Worker // 5, 6, 7,
1318*4bdc9457SAndroid Build Coastguard Worker // 8, 9, 10,
1319*4bdc9457SAndroid Build Coastguard Worker // 11, 12, 13,
1320*4bdc9457SAndroid Build Coastguard Worker // 14, 15, 16,
1321*4bdc9457SAndroid Build Coastguard Worker // 17, 18, 19]
1322*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot; the
  // expected list below has 24 entries, i.e. 6 channel slots for c == 5.
1323*4bdc9457SAndroid Build Coastguard Worker std::vector<float> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1324*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1325*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_dwconv_ghw_w(
1326*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1327*4bdc9457SAndroid Build Coastguard Worker h,
1328*4bdc9457SAndroid Build Coastguard Worker w,
1329*4bdc9457SAndroid Build Coastguard Worker c,
1330*4bdc9457SAndroid Build Coastguard Worker cr,
1331*4bdc9457SAndroid Build Coastguard Worker k.data(),
1332*4bdc9457SAndroid Build Coastguard Worker b.data(),
1333*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1334*4bdc9457SAndroid Build Coastguard Worker 0,
1335*4bdc9457SAndroid Build Coastguard Worker nullptr);
1336*4bdc9457SAndroid Build Coastguard Worker
1337*4bdc9457SAndroid Build Coastguard Worker std::vector<float> expected = {
1338*4bdc9457SAndroid Build Coastguard Worker // cr blocks
1339*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1340*4bdc9457SAndroid Build Coastguard Worker 0.0f, 1.0f,
1341*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1342*4bdc9457SAndroid Build Coastguard Worker 5.0f, 8.0f, 6.0f, 9.0f, 7.0f, 10.0f,
1343*4bdc9457SAndroid Build Coastguard Worker // bias again
1344*4bdc9457SAndroid Build Coastguard Worker 2.0f, 3.0f,
1345*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1346*4bdc9457SAndroid Build Coastguard Worker 11.0f, 14.0f, 12.0f, 15.0f, 13.0f, 16.0f,
1347*4bdc9457SAndroid Build Coastguard Worker // bias again
  // (only channel 4 remains; the second lane of this block stays zero)
1348*4bdc9457SAndroid Build Coastguard Worker 4.0f, 0.0f,
1349*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1350*4bdc9457SAndroid Build Coastguard Worker 17.0f, 0.0f, 18.0f, 0.0f, 19.0f, 0.0f,
1351*4bdc9457SAndroid Build Coastguard Worker };
1352*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1353*4bdc9457SAndroid Build Coastguard Worker }
1354*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_ghw_w for a 2x2 kernel over 2 channels when
// primary_tile (9) is larger than the kernel size (h * w == 4). The expected
// buffer holds the biases, then the four taps ordered column by column
// (h varies fastest), then zeros for the 5 unused taps of both channels.
// All values are small whole numbers stored as float, which the exact
// ASSERT_EQ comparison relies on.
TEST(PACK_F32_DWCONV_GHW_W,primary_tile_gt_kernel_size)1355*4bdc9457SAndroid Build Coastguard Worker TEST(PACK_F32_DWCONV_GHW_W, primary_tile_gt_kernel_size) {
1356*4bdc9457SAndroid Build Coastguard Worker size_t primary_tile = 9;
1357*4bdc9457SAndroid Build Coastguard Worker size_t h = 2;
1358*4bdc9457SAndroid Build Coastguard Worker size_t w = 2;
1359*4bdc9457SAndroid Build Coastguard Worker size_t c = 2;
1360*4bdc9457SAndroid Build Coastguard Worker size_t cr = 2;
1361*4bdc9457SAndroid Build Coastguard Worker
1362*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(c);
1363*4bdc9457SAndroid Build Coastguard Worker std::iota(b.begin(), b.end(), 0); // b = [0, 1]
  // Each row of the listing below is one kernel row; the first two rows
  // belong to channel 0 and the last two to channel 1.
1364*4bdc9457SAndroid Build Coastguard Worker std::vector<float> k(c * h * w); // k = [
1365*4bdc9457SAndroid Build Coastguard Worker // 2, 3,
1366*4bdc9457SAndroid Build Coastguard Worker // 4, 5,
1367*4bdc9457SAndroid Build Coastguard Worker // 6, 7,
1368*4bdc9457SAndroid Build Coastguard Worker // 8, 9]
1369*4bdc9457SAndroid Build Coastguard Worker std::iota(k.begin(), k.end(), b.size());
  // One bias entry plus primary_tile weight entries per channel slot, so the
  // buffer is sized by primary_tile rather than by the kernel size.
1370*4bdc9457SAndroid Build Coastguard Worker std::vector<float> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));
1371*4bdc9457SAndroid Build Coastguard Worker
  // NOTE(review): the trailing 0 and nullptr arguments are presumably the
  // extra-bytes count and packing params -- confirm against xnnpack/pack.h.
1372*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_dwconv_ghw_w(
1373*4bdc9457SAndroid Build Coastguard Worker primary_tile,
1374*4bdc9457SAndroid Build Coastguard Worker h,
1375*4bdc9457SAndroid Build Coastguard Worker w,
1376*4bdc9457SAndroid Build Coastguard Worker c,
1377*4bdc9457SAndroid Build Coastguard Worker cr,
1378*4bdc9457SAndroid Build Coastguard Worker k.data(),
1379*4bdc9457SAndroid Build Coastguard Worker b.data(),
1380*4bdc9457SAndroid Build Coastguard Worker packed_weights.data(),
1381*4bdc9457SAndroid Build Coastguard Worker 0,
1382*4bdc9457SAndroid Build Coastguard Worker nullptr);
1383*4bdc9457SAndroid Build Coastguard Worker
1384*4bdc9457SAndroid Build Coastguard Worker std::vector<float> expected = {
1385*4bdc9457SAndroid Build Coastguard Worker // bias first (cr == 2 of them)
1386*4bdc9457SAndroid Build Coastguard Worker 0.0f, 1.0f,
1387*4bdc9457SAndroid Build Coastguard Worker // then weights, channels first
1388*4bdc9457SAndroid Build Coastguard Worker 2.0f, 6.0f,
1389*4bdc9457SAndroid Build Coastguard Worker // go down the columns first
1390*4bdc9457SAndroid Build Coastguard Worker 4.0f, 8.0f, 3.0f, 7.0f, 5.0f, 9.0f,
1391*4bdc9457SAndroid Build Coastguard Worker // followed by 10 zeros to make up the difference with primary_tile
1392*4bdc9457SAndroid Build Coastguard Worker 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1393*4bdc9457SAndroid Build Coastguard Worker };
1394*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(expected, packed_weights);
1395*4bdc9457SAndroid Build Coastguard Worker }
1396*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_ghw_w for a 2x2 kernel with 5 channels, cr == 2
// and primary_tile == 9: the kernel has fewer taps than the primary tile and
// the channel count is not a multiple of cr.
TEST(PACK_F32_DWCONV_GHW_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  const size_t tile = 9;
  const size_t kernel_height = 2;
  const size_t kernel_width = 2;
  const size_t channels = 5;
  const size_t channel_tile = 2;

  // bias = [0, 1, 2, 3, 4]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [5, 6, ..., 24]: 20 values that continue counting after the bias.
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<float> packed(packed_size);

  xnn_pack_f32_dwconv_ghw_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected = {
      // Block 0: a pair of bias values (cr == 2).
      0.0f, 1.0f,
      // Weights, two channels side by side, walking down each column first.
      5.0f, 9.0f,
      7.0f, 11.0f,
      6.0f, 10.0f,
      8.0f, 12.0f,
      // (9 - 4) * 2 == 10 zeros filling the rest of the primary tile.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      // Block 1: bias pair.
      2.0f, 3.0f,
      // Weights in the same order as block 0.
      13.0f, 17.0f, 15.0f, 19.0f, 14.0f, 18.0f, 16.0f, 20.0f,
      // 10 zeros of tile padding.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      // Block 2: only one real channel is left, its partner slot is zero.
      4.0f, 0.0f,
      // Weights of the last channel, each followed by a zero.
      21.0f, 0.0f, 23.0f, 0.0f, 22.0f, 0.0f, 24.0f, 0.0f,
      // 10 zeros of tile padding.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
  };
  ASSERT_EQ(expected, packed);
}
1456*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_hwg_w for a 3x1 kernel with 2 channels, cr == 2
// and primary_tile == 3, i.e. the primary tile matches the kernel size.
TEST(PACK_F32_DWCONV_HWG_W, primary_tile_eq_kernel_size) {
  const size_t tile = 3;
  const size_t kernel_height = 3;
  const size_t kernel_width = 1;
  const size_t channels = 2;
  const size_t channel_tile = 2;

  // bias = [0, 1]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [2, 3, 4, 5, 6, 7]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<float> packed(packed_size);

  xnn_pack_f32_dwconv_hwg_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected = {
      // Bias pair comes first.
      0.0f, 1.0f,
      // Weights follow, two channels per row.
      2.0f, 3.0f,
      4.0f, 5.0f,
      6.0f, 7.0f,
  };
  ASSERT_EQ(expected, packed);
}
1492*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_hwg_w for a 3x1 kernel with 5 channels, cr == 2
// and primary_tile == 3: the channels span several cr-sized blocks and the
// last block is only partially filled.
TEST(PACK_F32_DWCONV_HWG_W, primary_tile_eq_kernel_size_channels_gt_cr) {
  const size_t tile = 3;
  const size_t kernel_height = 3;
  const size_t kernel_width = 1;
  const size_t channels = 5;
  const size_t channel_tile = 2;

  // bias = [0, 1, 2, 3, 4]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [
  //    5,  6,  7,  8,  9,
  //   10, 11, 12, 13, 14,
  //   15, 16, 17, 18, 19]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<float> packed(packed_size);

  xnn_pack_f32_dwconv_hwg_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected = {
      // Block 0: bias pair (cr == 2), then its weights.
      0.0f, 1.0f,
      5.0f, 6.0f, 10.0f, 11.0f, 15.0f, 16.0f,
      // Block 1: bias pair, then its weights.
      2.0f, 3.0f,
      7.0f, 8.0f, 12.0f, 13.0f, 17.0f, 18.0f,
      // Block 2: one real channel left, its partner slot is zero.
      4.0f, 0.0f,
      9.0f, 0.0f, 14.0f, 0.0f, 19.0f, 0.0f,
  };
  ASSERT_EQ(expected, packed);
}
1538*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_hwg_w for a 2x2 kernel with 2 channels, cr == 2
// and primary_tile == 9, i.e. the primary tile is larger than the kernel.
TEST(PACK_F32_DWCONV_HWG_W, primary_tile_gt_kernel_size) {
  const size_t tile = 9;
  const size_t kernel_height = 2;
  const size_t kernel_width = 2;
  const size_t channels = 2;
  const size_t channel_tile = 2;

  // bias = [0, 1]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [
  //   2, 3,
  //   4, 5,
  //   6, 7,
  //   8, 9]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<float> packed(packed_size);

  xnn_pack_f32_dwconv_hwg_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected = {
      // Bias pair (cr == 2).
      0.0f, 1.0f,
      // Weights, two channels per row, walking down each column first.
      2.0f, 3.0f,
      6.0f, 7.0f,
      4.0f, 5.0f,
      8.0f, 9.0f,
      // (9 - 4) * 2 == 10 zeros filling the rest of the primary tile.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
  };
  ASSERT_EQ(expected, packed);
}
1580*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_dwconv_hwg_w for a 2x2 kernel with 5 channels, cr == 2
// and primary_tile == 9: the kernel has fewer taps than the primary tile and
// the channel count is not a multiple of cr.
TEST(PACK_F32_DWCONV_HWG_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  const size_t tile = 9;
  const size_t kernel_height = 2;
  const size_t kernel_width = 2;
  const size_t channels = 5;
  const size_t channel_tile = 2;

  // bias = [0, 1, 2, 3, 4]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [
  //    5,  6,  7,  8,  9,
  //   10, 11, 12, 13, 14,
  //   15, 16, 17, 18, 19,
  //   20, 21, 22, 23, 24]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<float> packed(packed_size);

  xnn_pack_f32_dwconv_hwg_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected = {
      // Block 0: a pair of bias values (cr == 2).
      0.0f, 1.0f,
      // Weights, two channels side by side, walking down each column first.
      5.0f, 6.0f,
      15.0f, 16.0f,
      10.0f, 11.0f,
      20.0f, 21.0f,
      // (9 - 4) * 2 == 10 zeros filling the rest of the primary tile.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      // Block 1: bias pair.
      2.0f, 3.0f,
      // Weights in the same order as block 0.
      7.0f, 8.0f, 17.0f, 18.0f, 12.0f, 13.0f, 22.0f, 23.0f,
      // 10 zeros of tile padding.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      // Block 2: only one real channel is left, its partner slot is zero.
      4.0f, 0.0f,
      // Weights of the last channel, each followed by a zero.
      9.0f, 0.0f, 19.0f, 0.0f, 14.0f, 0.0f, 24.0f, 0.0f,
      // 10 zeros of tile padding.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
  };
  ASSERT_EQ(expected, packed);
}
1634*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_to_f16_dwconv_ghw_w for a 3x1 kernel with 2 channels,
// cr == 2 and primary_tile == 3. Inputs are float; the packed output is
// compared as uint16_t values produced by fp16_ieee_from_fp32_value.
TEST(PACK_F32_TO_F16_DWCONV_GHW_W, primary_tile_eq_kernel_size) {
  const size_t tile = 3;
  const size_t kernel_height = 3;
  const size_t kernel_width = 1;
  const size_t channels = 2;
  const size_t channel_tile = 2;

  // bias = [0, 1]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [2, 3, 4, 5, 6, 7]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<uint16_t> packed(packed_size);

  xnn_pack_f32_to_f16_dwconv_ghw_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected_float = {
      // Bias pair comes first.
      0.0f, 1.0f,
      // Weights follow, two channels per row.
      2.0f, 5.0f,
      3.0f, 6.0f,
      4.0f, 7.0f,
  };

  // Convert the reference values to half precision before comparing.
  std::vector<uint16_t> expected;
  expected.reserve(expected_float.size());
  for (const float value : expected_float) {
    expected.push_back(fp16_ieee_from_fp32_value(value));
  }
  ASSERT_EQ(expected, packed);
}
1673*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_to_f16_dwconv_ghw_w for a 3x1 kernel with 5 channels,
// cr == 2 and primary_tile == 3: the channels span several cr-sized blocks
// and the last block is only partially filled.
TEST(PACK_F32_TO_F16_DWCONV_GHW_W, primary_tile_eq_kernel_size_channels_gt_cr) {
  const size_t tile = 3;
  const size_t kernel_height = 3;
  const size_t kernel_width = 1;
  const size_t channels = 5;
  const size_t channel_tile = 2;

  // bias = [0, 1, 2, 3, 4]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [
  //    5,  6,  7,
  //    8,  9, 10,
  //   11, 12, 13,
  //   14, 15, 16,
  //   17, 18, 19]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<uint16_t> packed(packed_size);

  xnn_pack_f32_to_f16_dwconv_ghw_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected_float = {
      // Block 0: bias pair (cr == 2), then its weights.
      0.0f, 1.0f,
      5.0f, 8.0f, 6.0f, 9.0f, 7.0f, 10.0f,
      // Block 1: bias pair, then its weights.
      2.0f, 3.0f,
      11.0f, 14.0f, 12.0f, 15.0f, 13.0f, 16.0f,
      // Block 2: one real channel left, its partner slot is zero.
      4.0f, 0.0f,
      17.0f, 0.0f, 18.0f, 0.0f, 19.0f, 0.0f,
  };

  // Convert the reference values to half precision before comparing.
  std::vector<uint16_t> expected;
  expected.reserve(expected_float.size());
  for (const float value : expected_float) {
    expected.push_back(fp16_ieee_from_fp32_value(value));
  }
  ASSERT_EQ(expected, packed);
}
1724*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_to_f16_dwconv_ghw_w for a 2x2 kernel with 2 channels,
// cr == 2 and primary_tile == 9, i.e. the primary tile is larger than the
// kernel.
TEST(PACK_F32_TO_F16_DWCONV_GHW_W, primary_tile_gt_kernel_size) {
  const size_t tile = 9;
  const size_t kernel_height = 2;
  const size_t kernel_width = 2;
  const size_t channels = 2;
  const size_t channel_tile = 2;

  // bias = [0, 1]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [
  //   2, 3,
  //   4, 5,
  //   6, 7,
  //   8, 9]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<uint16_t> packed(packed_size);

  xnn_pack_f32_to_f16_dwconv_ghw_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected_float = {
      // Bias pair (cr == 2).
      0.0f, 1.0f,
      // Weights, two channels per row, walking down each column first.
      2.0f, 6.0f,
      4.0f, 8.0f,
      3.0f, 7.0f,
      5.0f, 9.0f,
      // (9 - 4) * 2 == 10 zeros filling the rest of the primary tile.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
  };

  // Convert the reference values to half precision before comparing.
  std::vector<uint16_t> expected;
  expected.reserve(expected_float.size());
  for (const float value : expected_float) {
    expected.push_back(fp16_ieee_from_fp32_value(value));
  }
  ASSERT_EQ(expected, packed);
}
1769*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_to_f16_dwconv_ghw_w for a 2x2 kernel with 5 channels,
// cr == 2 and primary_tile == 9: the kernel has fewer taps than the primary
// tile and the channel count is not a multiple of cr.
TEST(PACK_F32_TO_F16_DWCONV_GHW_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  const size_t tile = 9;
  const size_t kernel_height = 2;
  const size_t kernel_width = 2;
  const size_t channels = 5;
  const size_t channel_tile = 2;

  // bias = [0, 1, 2, 3, 4]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [5, 6, ..., 24]: 20 values that continue counting after the bias.
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<uint16_t> packed(packed_size);

  xnn_pack_f32_to_f16_dwconv_ghw_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected_float = {
      // Block 0: a pair of bias values (cr == 2).
      0.0f, 1.0f,
      // Weights, two channels side by side, walking down each column first.
      5.0f, 9.0f,
      7.0f, 11.0f,
      6.0f, 10.0f,
      8.0f, 12.0f,
      // (9 - 4) * 2 == 10 zeros filling the rest of the primary tile.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      // Block 1: bias pair.
      2.0f, 3.0f,
      // Weights in the same order as block 0.
      13.0f, 17.0f, 15.0f, 19.0f, 14.0f, 18.0f, 16.0f, 20.0f,
      // 10 zeros of tile padding.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      // Block 2: only one real channel is left, its partner slot is zero.
      4.0f, 0.0f,
      // Weights of the last channel, each followed by a zero.
      21.0f, 0.0f, 23.0f, 0.0f, 22.0f, 0.0f, 24.0f, 0.0f,
      // 10 zeros of tile padding.
      0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
  };

  // Convert the reference values to half precision before comparing.
  std::vector<uint16_t> expected;
  expected.reserve(expected_float.size());
  for (const float value : expected_float) {
    expected.push_back(fp16_ieee_from_fp32_value(value));
  }
  ASSERT_EQ(expected, packed);
}
1832*4bdc9457SAndroid Build Coastguard Worker
// Checks xnn_pack_f32_to_f16_dwconv_hwg_w for a 3x1 kernel with 2 channels,
// cr == 2 and primary_tile == 3. Inputs are float; the packed output is
// compared as uint16_t values produced by fp16_ieee_from_fp32_value.
TEST(PACK_F32_TO_F16_DWCONV_HWG_W, primary_tile_eq_kernel_size) {
  const size_t tile = 3;
  const size_t kernel_height = 3;
  const size_t kernel_width = 1;
  const size_t channels = 2;
  const size_t channel_tile = 2;

  // bias = [0, 1]
  std::vector<float> bias(channels);
  std::iota(bias.begin(), bias.end(), 0);

  // kernel = [2, 3, 4, 5, 6, 7]
  std::vector<float> kernel(channels * kernel_height * kernel_width);
  std::iota(kernel.begin(), kernel.end(), bias.size());

  // Zero-initialized output: (tile + 1) entries per rounded-up channel.
  const size_t packed_size = (tile + 1) * round_up_po2(channels, channel_tile);
  std::vector<uint16_t> packed(packed_size);

  xnn_pack_f32_to_f16_dwconv_hwg_w(
      tile, kernel_height, kernel_width, channels, channel_tile,
      kernel.data(), bias.data(), packed.data(),
      0, nullptr);

  const std::vector<float> expected_float = {
      // Bias pair comes first.
      0.0f, 1.0f,
      // Weights follow, two channels per row.
      2.0f, 3.0f,
      4.0f, 5.0f,
      6.0f, 7.0f,
  };

  // Convert the reference values to half precision before comparing.
  std::vector<uint16_t> expected;
  expected.reserve(expected_float.size());
  for (const float value : expected_float) {
    expected.push_back(fp16_ieee_from_fp32_value(value));
  }
  ASSERT_EQ(expected, packed);
}
1871*4bdc9457SAndroid Build Coastguard Worker
// Packs a 3x1 kernel (h * w == primary_tile == 3) whose channel count (c == 5)
// is larger than the channel tile (cr == 2), converting f32 inputs to f16.
//
// The packed buffer is sized for round_up_po2(c, cr) == 6 channels, i.e. three
// cr-wide blocks of (1 bias + primary_tile weights) entries each: 24 values.
// The last block covers only channel 4, so its second lane is expected to be 0.
TEST(PACK_F32_TO_F16_DWCONV_HWG_W, primary_tile_eq_kernel_size_channels_gt_cr) {
  size_t primary_tile = 3;
  size_t h = 3;
  size_t w = 1;
  size_t c = 5;
  size_t cr = 2;

  std::vector<float> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  // k continues counting where b stopped. Each commented row below is one
  // kernel tap, holding one value per channel:
  //   k = [
  //      5,  6,  7,  8,  9,
  //     10, 11, 12, 13, 14,
  //     15, 16, 17, 18, 19]
  std::vector<float> k(c * h * w);
  std::iota(k.begin(), k.end(), b.size());
  // std::vector value-initializes its elements, so every slot starts at 0 and
  // padded lanes compare equal to 0 whether or not the packer writes them.
  std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));

  // NOTE(review): the trailing `0` and `nullptr` are presumably the
  // extra-bytes and params arguments -- confirm against xnnpack/pack.h.
  xnn_pack_f32_to_f16_dwconv_hwg_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    nullptr);

  std::vector<float> expected_float = {
    // cr blocks
    // bias first (cr == 2 of them)
    0.0f, 1.0f,
    // then weights, channels first
    5.0f, 6.0f, 10.0f, 11.0f, 15.0f, 16.0f,
    // bias again
    2.0f, 3.0f,
    // then weights, channels first
    7.0f, 8.0f, 12.0f, 13.0f, 17.0f, 18.0f,
    // bias again (only channel 4 is real, the second lane is padding)
    4.0f, 0.0f,
    // then weights, channels first (second lane is padding)
    9.0f, 0.0f, 14.0f, 0.0f, 19.0f, 0.0f,
  };
  // Convert the readable float expectation to the f16 bit patterns that the
  // packer is expected to emit.
  std::vector<uint16_t> expected(expected_float.size());
  std::transform(expected_float.begin(), expected_float.end(), expected.begin(),
                 [](float f) { return fp16_ieee_from_fp32_value(f); });
  ASSERT_EQ(expected, packed_weights);
}
1920*4bdc9457SAndroid Build Coastguard Worker
// Packs a 2x2 kernel (4 taps) into a primary tile of 9 taps with c == cr == 2,
// converting f32 inputs to f16.
//
// The packed buffer holds (1 bias + primary_tile) * cr == 20 values: 2 biases,
// 4 taps * 2 channels of weights, and (9 - 4) * 2 == 10 trailing zeros for the
// taps the kernel does not have.
TEST(PACK_F32_TO_F16_DWCONV_HWG_W, primary_tile_gt_kernel_size) {
  size_t primary_tile = 9;
  size_t h = 2;
  size_t w = 2;
  size_t c = 2;
  size_t cr = 2;

  std::vector<float> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
  // k continues counting where b stopped. Each commented row below is one
  // kernel tap, holding one value per channel:
  //   k = [
  //     2, 3,
  //     4, 5,
  //     6, 7,
  //     8, 9]
  std::vector<float> k(c * h * w);
  std::iota(k.begin(), k.end(), b.size());
  // std::vector value-initializes its elements, so every slot starts at 0 and
  // the unused tail compares equal to 0 whether or not the packer writes it.
  std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));

  // NOTE(review): the trailing `0` and `nullptr` are presumably the
  // extra-bytes and params arguments -- confirm against xnnpack/pack.h.
  xnn_pack_f32_to_f16_dwconv_hwg_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    nullptr);

  std::vector<float> expected_float = {
    // bias first (cr == 2 of them)
    0.0f, 1.0f,
    // then weights, channels first
    2.0f, 3.0f,
    // go down the columns first: input taps are emitted in the order 0, 2, 1, 3
    6.0f, 7.0f, 4.0f, 5.0f, 8.0f, 9.0f,
    // followed by 10 zeros to make up the difference with primary_tile
    0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
  };
  // Convert the readable float expectation to the f16 bit patterns that the
  // packer is expected to emit.
  std::vector<uint16_t> expected(expected_float.size());
  std::transform(expected_float.begin(), expected_float.end(), expected.begin(),
                 [](float f) { return fp16_ieee_from_fp32_value(f); });
  ASSERT_EQ(expected, packed_weights);
}
1965*4bdc9457SAndroid Build Coastguard Worker
// Packs a 2x2 kernel (4 taps) into a primary tile of 9 taps when the channel
// count (c == 5) is larger than the channel tile (cr == 2), converting f32
// inputs to f16.
//
// The packed buffer is sized for round_up_po2(c, cr) == 6 channels, i.e. three
// cr-wide blocks of (1 bias + primary_tile) * cr == 20 values each. Every
// block ends with (9 - 4) * 2 == 10 zeros for the taps the kernel does not
// have, and the last block covers only channel 4, so its second lane is 0.
TEST(PACK_F32_TO_F16_DWCONV_HWG_W, primary_tile_gt_kernel_size_channels_gt_cr) {
  size_t primary_tile = 9;
  size_t h = 2;
  size_t w = 2;
  size_t c = 5;
  size_t cr = 2;

  std::vector<float> b(c);
  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4]
  // k continues counting where b stopped. Each commented row below is one
  // kernel tap, holding one value per channel:
  //   k = [
  //      5,  6,  7,  8,  9,
  //     10, 11, 12, 13, 14,
  //     15, 16, 17, 18, 19,
  //     20, 21, 22, 23, 24]
  std::vector<float> k(c * h * w);
  std::iota(k.begin(), k.end(), b.size());
  // std::vector value-initializes its elements, so every slot starts at 0 and
  // padded entries compare equal to 0 whether or not the packer writes them.
  std::vector<uint16_t> packed_weights(((primary_tile + 1) * round_up_po2(c, cr)));

  // NOTE(review): the trailing `0` and `nullptr` are presumably the
  // extra-bytes and params arguments -- confirm against xnnpack/pack.h.
  xnn_pack_f32_to_f16_dwconv_hwg_w(
    primary_tile,
    h,
    w,
    c,
    cr,
    k.data(),
    b.data(),
    packed_weights.data(),
    0,
    nullptr);

  std::vector<float> expected_float = {
    // bias first (cr == 2 of them)
    0.0f, 1.0f,
    // then weights, channels first
    5.0f, 6.0f,
    // go down the columns first: input taps are emitted in the order 0, 2, 1, 3
    15.0f, 16.0f,
    10.0f, 11.0f,
    20.0f, 21.0f,
    // followed by 10 zeros to make up the difference with primary_tile
    0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
    // bias first (cr == 2 of them)
    2.0f, 3.0f,
    // then weights, channels first
    7.0f, 8.0f, 17.0f, 18.0f, 12.0f, 13.0f, 22.0f, 23.0f,
    // followed by 10 zeros to make up the difference with primary_tile
    0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
    // bias (only channel 4 is real, the second lane is padding)
    4.0f, 0.0f,
    // weights (second lane is padding)
    9.0f, 0.0f, 19.0f, 0.0f, 14.0f, 0.0f, 24.0f, 0.0f,
    // followed by 10 zeros to make up the difference with primary_tile
    0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
  };
  // Convert the readable float expectation to the f16 bit patterns that the
  // packer is expected to emit.
  std::vector<uint16_t> expected(expected_float.size());
  std::transform(expected_float.begin(), expected_float.end(), expected.begin(),
                 [](float f) { return fp16_ieee_from_fp32_value(f); });
  ASSERT_EQ(expected, packed_weights);
}
2022