xref: /aosp_15_r20/external/XNNPACK/bench/spmm.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
8 #include <benchmark/benchmark.h>
9 
// Registers `spmm_fn` once per model listed below. Each registration passes the
// model's display string to `spmm_fn` through BENCHMARK_CAPTURE, expands the
// model's problem sizes with the matching *SpmmArguments function defined in
// this header, and calls UseRealTime().
#define BENCHMARK_SPMM(spmm_fn) \
  BENCHMARK_CAPTURE(spmm_fn, mobilenet_v1, "MobileNet v1") \
      ->Apply(MobileNetV1SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, mobilenet_v2, "MobileNet v2") \
      ->Apply(MobileNetV2SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, mobilenet_v3_small, "MobileNet v3 Small") \
      ->Apply(MobileNetV3SmallSpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, mobilenet_v3_large, "MobileNet v3 Large") \
      ->Apply(MobileNetV3LargeSpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v1_g1, "ShuffleNet v1 (1 group)") \
      ->Apply(ShuffleNetV1G1SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)") \
      ->Apply(ShuffleNetV1G2SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)") \
      ->Apply(ShuffleNetV1G3SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)") \
      ->Apply(ShuffleNetV1G4SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)") \
      ->Apply(ShuffleNetV1G8SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v2_x05, "ShuffleNet v2 0.5X") \
      ->Apply(ShuffleNetV2X05SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v2_x10, "ShuffleNet v2 1.0X") \
      ->Apply(ShuffleNetV2X10SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v2_x15, "ShuffleNet v2 1.5X") \
      ->Apply(ShuffleNetV2X15SpmmArguments) \
      ->UseRealTime(); \
  BENCHMARK_CAPTURE(spmm_fn, shufflenet_v2_x20, "ShuffleNet v2 2.0X") \
      ->Apply(ShuffleNetV2X20SpmmArguments) \
      ->UseRealTime();
24 
25 
26 // ShuffleNet v1 with 1 group.
ShuffleNetV1G1SpmmArguments(benchmark::internal::Benchmark * b)27 static void ShuffleNetV1G1SpmmArguments(benchmark::internal::Benchmark* b) {
28   b->ArgNames({"M", "N", "K"});
29 
30   /*          M      N    K */
31   b->Args({56 * 56,  36,  24});
32   b->Args({28 * 28, 120,  36});
33   b->Args({28 * 28,  36, 144});
34   b->Args({28 * 28, 144,  36});
35   b->Args({28 * 28,  72, 144});
36   b->Args({14 * 14, 144,  72});
37   b->Args({14 * 14,  72, 288});
38   b->Args({14 * 14, 288,  72});
39   b->Args({14 * 14, 144, 288});
40   b->Args({ 7 *  7, 288, 144});
41   b->Args({ 7 *  7, 144, 576});
42   b->Args({ 7 *  7, 576, 144});
43 }
44 
45 // ShuffleNet v1 with 2 groups.
ShuffleNetV1G2SpmmArguments(benchmark::internal::Benchmark * b)46 static void ShuffleNetV1G2SpmmArguments(benchmark::internal::Benchmark* b) {
47   b->ArgNames({"M", "N", "K"});
48 
49   /*          M      N    K */
50   b->Args({56 * 56,  50,  24});
51   b->Args({28 * 28,  88,  25});
52   b->Args({28 * 28,  25, 100});
53   b->Args({28 * 28, 100,  25});
54   b->Args({28 * 28,  50, 100});
55   b->Args({14 * 14, 100,  50});
56   b->Args({14 * 14,  50, 200});
57   b->Args({14 * 14, 200,  50});
58   b->Args({14 * 14, 100, 200});
59   b->Args({ 7 *  7, 200, 100});
60   b->Args({ 7 *  7, 100, 400});
61   b->Args({ 7 *  7, 400, 100});
62 }
63 
64 // ShuffleNet v1 with 3 groups.
ShuffleNetV1G3SpmmArguments(benchmark::internal::Benchmark * b)65 static void ShuffleNetV1G3SpmmArguments(benchmark::internal::Benchmark* b) {
66   b->ArgNames({"M", "N", "K"});
67 
68   /*          M      N    K */
69   b->Args({56 * 56,  60,  24});
70   b->Args({28 * 28,  72,  20});
71   b->Args({28 * 28,  20,  80});
72   b->Args({28 * 28,  80,  20});
73   b->Args({28 * 28,  40,  80});
74   b->Args({14 * 14,  80,  40});
75   b->Args({14 * 14,  40, 160});
76   b->Args({14 * 14, 160,  40});
77   b->Args({14 * 14,  80, 160});
78   b->Args({ 7 *  7, 160,  80});
79   b->Args({ 7 *  7,  80, 320});
80   b->Args({ 7 *  7, 320,  80});
81 }
82 
83 // ShuffleNet v1 with 4 groups.
ShuffleNetV1G4SpmmArguments(benchmark::internal::Benchmark * b)84 static void ShuffleNetV1G4SpmmArguments(benchmark::internal::Benchmark* b) {
85   b->ArgNames({"M", "N", "K"});
86 
87   /*          M      N    K */
88   b->Args({56 * 56,  68,  24});
89   b->Args({28 * 28,  62,  17});
90   b->Args({28 * 28,  17,  68});
91   b->Args({28 * 28,  68,  17});
92   b->Args({28 * 28,  34,  68});
93   b->Args({14 * 14,  68,  34});
94   b->Args({14 * 14,  34, 136});
95   b->Args({14 * 14, 136,  34});
96   b->Args({14 * 14,  68, 136});
97   b->Args({ 7 *  7, 136,  68});
98   b->Args({ 7 *  7,  68, 272});
99   b->Args({ 7 *  7, 272,  68});
100 }
101 
102 // ShuffleNet v1 with 8 groups.
ShuffleNetV1G8SpmmArguments(benchmark::internal::Benchmark * b)103 static void ShuffleNetV1G8SpmmArguments(benchmark::internal::Benchmark* b) {
104   b->ArgNames({"M", "N", "K"});
105 
106   /*          M      N    K */
107   b->Args({56 * 56,  96,  24});
108   b->Args({28 * 28,  45,  12});
109   b->Args({28 * 28,  12,  48});
110   b->Args({28 * 28,  48,  12});
111   b->Args({28 * 28,  24,  48});
112   b->Args({14 * 14,  48,  24});
113   b->Args({14 * 14,  24,  96});
114   b->Args({14 * 14,  96,  24});
115   b->Args({14 * 14,  48,  96});
116   b->Args({ 7 *  7,  96,  48});
117   b->Args({ 7 *  7,  48, 192});
118   b->Args({ 7 *  7, 192,  48});
119 }
120 
121 // ShuffleNet v2 (0.5X scale)
ShuffleNetV2X05SpmmArguments(benchmark::internal::Benchmark * b)122 static void ShuffleNetV2X05SpmmArguments(benchmark::internal::Benchmark* b) {
123   b->ArgNames({"M", "N", "K"});
124 
125   /*          M       N    K */
126   b->Args({56 * 56,   24,  24});
127   b->Args({28 * 28,   24,  24});
128   b->Args({28 * 28,   48,  48});
129   b->Args({14 * 14,   48,  48});
130   b->Args({14 * 14,   96,  96});
131   b->Args({ 7 *  7,   96,  96});
132   b->Args({ 7 *  7, 1024, 192});
133 }
134 
135 // ShuffleNet v2 (1.0X scale)
ShuffleNetV2X10SpmmArguments(benchmark::internal::Benchmark * b)136 static void ShuffleNetV2X10SpmmArguments(benchmark::internal::Benchmark* b) {
137   b->ArgNames({"M", "N", "K"});
138 
139   /*          M       N    K */
140   b->Args({56 * 56,   58,  24});
141   b->Args({28 * 28,   58,  24});
142   b->Args({28 * 28,   58,  58});
143   b->Args({14 * 14,  116, 116});
144   b->Args({14 * 14,  116, 116});
145   b->Args({14 * 14,  232, 232});
146   b->Args({ 7 *  7,  232, 232});
147   b->Args({ 7 *  7, 1024, 464});
148 }
149 
150 // ShuffleNet v2 (1.5X scale)
ShuffleNetV2X15SpmmArguments(benchmark::internal::Benchmark * b)151 static void ShuffleNetV2X15SpmmArguments(benchmark::internal::Benchmark* b) {
152   b->ArgNames({"M", "N", "K"});
153 
154   /*          M       N    K */
155   b->Args({56 * 56,   88,  24});
156   b->Args({28 * 28,   88,  24});
157   b->Args({28 * 28,   88,  88});
158   b->Args({28 * 28,  176, 176});
159   b->Args({14 * 14,  176, 176});
160   b->Args({14 * 14,  352, 352});
161   b->Args({ 7 *  7,  352, 352});
162   b->Args({ 7 *  7, 1024, 704});
163 }
164 
165 // ShuffleNet v2 (2.0X scale)
ShuffleNetV2X20SpmmArguments(benchmark::internal::Benchmark * b)166 static void ShuffleNetV2X20SpmmArguments(benchmark::internal::Benchmark* b) {
167   b->ArgNames({"M", "N", "K"});
168 
169   /*          M       N    K */
170   b->Args({56 * 56,  122,  24});
171   b->Args({28 * 28,  122,  24});
172   b->Args({28 * 28,  122, 122});
173   b->Args({28 * 28,  244, 244});
174   b->Args({14 * 14,  244, 244});
175   b->Args({14 * 14,  488, 488});
176   b->Args({ 7 *  7,  488, 488});
177   b->Args({ 7 *  7, 2048, 976});
178 }
179 
MobileNetV1SpmmArguments(benchmark::internal::Benchmark * b)180 static void MobileNetV1SpmmArguments(benchmark::internal::Benchmark* b) {
181   b->ArgNames({"M", "N", "K"});
182 
183   /*           M        N     K */
184   b->Args({112 * 112,   64,   32});
185   b->Args({ 56 *  56,  128,   64});
186   b->Args({ 56 *  56,  128,  128});
187   b->Args({ 28 *  28,  256,  128});
188   b->Args({ 28 *  28,  256,  256});
189   b->Args({ 14 *  14,  512,  256});
190   b->Args({ 14 *  14,  512,  512});
191   b->Args({  7 *   7, 1024,  512});
192   b->Args({  7 *   7, 1024, 1024});
193 }
194 
MobileNetV2SpmmArguments(benchmark::internal::Benchmark * b)195 static void MobileNetV2SpmmArguments(benchmark::internal::Benchmark* b) {
196   b->ArgNames({"M", "N", "K"});
197 
198   /******** Bottleneck 1 *******/
199   /*           M        N    K */
200   b->Args({112 * 112,   16,  32});
201   /******** Bottleneck 2 *******/
202   /*           M        N    K */
203   b->Args({112 * 112,   96,  16});
204   b->Args({ 56 *  56,   24,  96});
205   b->Args({ 56 *  56,  144,  24});
206   b->Args({ 56 *  56,   24, 144});
207   /******** Bottleneck 3 *******/
208   /*           M        N    K */
209   b->Args({ 28 *  28,   32, 144});
210   b->Args({ 28 *  28,  192,  32});
211   b->Args({ 28 *  28,   32, 192});
212   /******** Bottleneck 4 *******/
213   /*           M        N    K */
214   b->Args({ 14 *  14,   64, 192});
215   b->Args({ 14 *  14,  384,  64});
216   b->Args({ 14 *  14,   64, 384});
217   /******** Bottleneck 5 *******/
218   /*           M        N    K */
219   b->Args({ 14 *  14,   96, 384});
220   b->Args({ 14 *  14,  576,  96});
221   b->Args({ 14 *  14,   96, 576});
222   /******** Bottleneck 6 *******/
223   /*           M        N    K */
224   b->Args({  7 *   7,  160, 576});
225   b->Args({  7 *   7,  960, 160});
226   b->Args({  7 *   7,  160, 960});
227   /******** Bottleneck 7 *******/
228   /*           M        N    K */
229   b->Args({  7 *   7,  320, 960});
230   /***** Pre-pooling Conv2D ****/
231   /*           M        N    K */
232   b->Args({  7 *   7, 1280, 320});
233 }
234 
MobileNetV3SmallSpmmArguments(benchmark::internal::Benchmark * b)235 static void MobileNetV3SmallSpmmArguments(benchmark::internal::Benchmark* b) {
236   b->ArgNames({"M", "N", "K"});
237 
238   /****** Bottleneck 1 ******/
239   /*          M      N    K */
240   b->Args({ 1 *  1,   8,  16});
241   b->Args({ 1 *  1,  16,   8});
242   b->Args({56 * 56,  16,  16});
243   /****** Bottleneck 2 ******/
244   /*          M      N    K */
245   b->Args({56 * 56,  72,  16});
246   b->Args({28 * 28,  24,  72});
247   /****** Bottleneck 3 ******/
248   /*          M      N    K */
249   b->Args({28 * 28,  88,  24});
250   b->Args({28 * 28,  24,  88});
251   /****** Bottleneck 4 ******/
252   /*          M      N    K */
253   b->Args({28 * 28,  96,  24});
254   b->Args({ 1 *  1,  24,  96});
255   b->Args({ 1 *  1,  96,  24});
256   b->Args({14 * 14,  40,  96});
257   /****** Bottleneck 5 ******/
258   /*          M      N    K */
259   b->Args({14 * 14, 240,  40});
260   b->Args({ 1 *  1,  64, 240});
261   b->Args({ 1 *  1, 240,  64});
262   b->Args({14 * 14,  40, 240});
263   /****** Bottleneck 6 ******/
264   /*          M      N    K */
265 //b->Args({14 * 14, 240,  40});
266 //b->Args({ 1 *  1,  64, 240});
267 //b->Args({ 1 *  1, 240,  64});
268 //b->Args({14 * 14,  40, 240});
269   /****** Bottleneck 7 ******/
270   /*          M      N    K */
271   b->Args({14 * 14, 120,  40});
272   b->Args({ 1 *  1,  32, 120});
273   b->Args({ 1 *  1, 120,  32});
274   b->Args({14 * 14,  48, 120});
275   /****** Bottleneck 8 ******/
276   /*          M      N    K */
277   b->Args({14 * 14, 144,  48});
278   b->Args({ 1 *  1,  40, 144});
279   b->Args({ 1 *  1, 144,  40});
280   b->Args({14 * 14,  48, 144});
281   /****** Bottleneck 9 ******/
282   /*          M      N    K */
283   b->Args({14 * 14, 288,  48});
284   b->Args({ 1 *  1,  72, 288});
285   b->Args({ 1 *  1, 288,  72});
286   b->Args({ 7 *  7,  96, 288});
287   /****** Bottleneck 10 *****/
288   /*          M      N     K */
289   b->Args({ 7 *  7, 576,  96});
290   b->Args({ 1 *  1, 144, 576});
291   b->Args({ 1 *  1, 576, 144});
292   b->Args({ 7 *  7,  96, 576});
293   /****** Bottleneck 11 *****/
294   /*          M      N    K */
295 //b->Args({ 7 *  7, 576,  96});
296 //b->Args({ 1 *  1, 144, 576});
297 //b->Args({ 1 *  1, 576, 144});
298 //b->Args({ 7 *  7,  96, 576});
299   /******* Last Stage *******/
300   /*          M      N    K */
301 //b->Args({ 7 *  7, 576,  96});
302 }
303 
MobileNetV3LargeSpmmArguments(benchmark::internal::Benchmark * b)304 static void MobileNetV3LargeSpmmArguments(benchmark::internal::Benchmark* b) {
305   b->ArgNames({"M", "N", "K"});
306 
307   /******* Bottleneck 1 *******/
308   /*           M       N    K */
309   b->Args({112 * 112,  16,  16});
310   /******* Bottleneck 2 *******/
311   /*           M       N    K */
312   b->Args({112 * 112,  64,  16});
313   b->Args({ 56 *  56,  24,  64});
314   /******* Bottleneck 3 *******/
315   /*           M       N    K */
316   b->Args({ 56 *  56,  72,  24});
317   b->Args({ 56 *  56,  24,  72});
318   /******* Bottleneck 4 *******/
319   /*           M       N    K */
320 //b->Args({ 56 *  56,  72,  24});
321   b->Args({  1 *   1,  24,  72});
322   b->Args({  1 *   1,  72,  24});
323   b->Args({ 28 *  28,  40,  72});
324   /******* Bottleneck 5 *******/
325   /*           M       N    K */
326   b->Args({ 28 *  28, 120,  40});
327   b->Args({  1 *   1,  32, 120});
328   b->Args({  1 *   1, 120,  32});
329   b->Args({ 28 *  28,  40, 120});
330   /******* Bottleneck 6 *******/
331   /*           M       N    K */
332 //b->Args({ 28 *  28, 120,  40});
333 //b->Args({  1 *   1,  32, 120});
334 //b->Args({  1 *   1, 120,  32});
335 //b->Args({ 28 *  28,  40, 120});
336   /******* Bottleneck 7 *******/
337   /*           M       N    K */
338   b->Args({ 28 *  28, 240,  40});
339   b->Args({ 14 *  14,  80, 240});
340   /******* Bottleneck 8 *******/
341   /*           M       N    K */
342   b->Args({ 14 *  14, 200,  80});
343   b->Args({ 14 *  14,  80, 200});
344   /******* Bottleneck 9 *******/
345   /*           M       N    K */
346   b->Args({ 14 *  14, 184,  80});
347   b->Args({ 14 *  14,  80, 184});
348   /******* Bottleneck 10 ******/
349   /*           M       N    K */
350   b->Args({ 14 *  14, 184,  80});
351   b->Args({ 14 *  14,  80, 184});
352   /******* Bottleneck 11 ******/
353   /*           M       N    K */
354   b->Args({ 14 *  14, 480,  80});
355   b->Args({  1 *   1, 120, 480});
356   b->Args({  1 *   1, 480, 120});
357   b->Args({ 14 *  14, 112, 480});
358   /******* Bottleneck 12 ******/
359   /*           M       N    K */
360   b->Args({ 14 *  14, 672, 112});
361   b->Args({  1 *   1, 168, 672});
362   b->Args({  1 *   1, 672, 168});
363   b->Args({ 14 *  14, 112, 672});
364   /******* Bottleneck 13 ******/
365   /*           M       N    K */
366 //b->Args({ 14 *  14, 672, 112});
367 //b->Args({  1 *   1, 168, 672});
368 //b->Args({  1 *   1, 672, 168});
369   b->Args({  7 *   7, 160, 672});
370   /******* Bottleneck 14 ******/
371   /*           M       N    K */
372   b->Args({  7 *   7, 960, 160});
373   b->Args({  1 *   1, 240, 960});
374   b->Args({  1 *   1, 960, 240});
375   b->Args({  7 *   7, 160, 960});
376   /******* Bottleneck 15 ******/
377   /*           M       N    K */
378 //b->Args({  7 *   7, 960, 160});
379 //b->Args({  1 *   1, 240, 960});
380 //b->Args({  1 *   1, 960, 240});
381 //b->Args({  7 *   7, 160, 960});
382   /******** Last Stage  *******/
383   /*           M       N    K */
384 //b->Args({  7 *   7, 960, 160});
385 }
386