xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/bench/convolution.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <algorithm>
10 #include <cfloat>
11 #include <chrono>
12 #include <cmath>
13 #include <functional>
14 #include <iostream>
15 #include <random>
16 #include <vector>
17 
18 #include <pytorch_qnnpack.h>
19 
20 #include <benchmark/benchmark.h>
21 
convolution_q8(benchmark::State & state,const char * net,bool per_channel=false)22 static void convolution_q8(benchmark::State& state, const char* net, bool per_channel=false) {
23   const size_t batchSize = state.range(0);
24   const size_t inputHeight = state.range(1);
25   const size_t inputWidth = state.range(2);
26   const size_t kernelHeight = state.range(3);
27   const size_t kernelWidth = state.range(4);
28   const size_t subsampling = state.range(5);
29   const size_t dilation = state.range(6);
30   const size_t groups = state.range(7);
31   const size_t groupInputChannels = state.range(8);
32   const size_t groupOutputChannels = state.range(9);
33 
34   std::random_device randomDevice;
35   auto rng = std::mt19937(randomDevice());
36   auto s32rng =
37       std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
38   auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
39 
40   const size_t outputPixelStride = groups * groupOutputChannels;
41   const size_t inputPixelStride = groups * groupInputChannels;
42   const size_t effectiveKernelHeight = (kernelHeight - 1) * dilation + 1;
43   const size_t effectiveKernelWidth = (kernelWidth - 1) * dilation + 1;
44   const size_t paddingWidth = effectiveKernelWidth / 2;
45   const size_t paddingHeight = effectiveKernelHeight / 2;
46   const size_t outputHeight =
47       (inputHeight + paddingHeight * 2 - effectiveKernelHeight) / subsampling +
48       1;
49   const size_t outputWidth =
50       (inputWidth + paddingWidth * 2 - effectiveKernelWidth) / subsampling + 1;
51 
52   std::vector<uint8_t> input(
53       batchSize * inputHeight * inputWidth * inputPixelStride);
54   std::generate(input.begin(), input.end(), std::ref(u8rng));
55   std::vector<uint8_t> kernel(
56       groups * groupOutputChannels * kernelHeight * kernelWidth *
57       groupInputChannels);
58   std::generate(kernel.begin(), kernel.end(), std::ref(u8rng));
59   std::vector<int32_t> bias(groups * groupOutputChannels);
60   std::generate(bias.begin(), bias.end(), std::ref(s32rng));
61   std::vector<uint8_t> output(
62       batchSize * outputHeight * outputWidth * outputPixelStride);
63 
64   pytorch_qnnp_status status = pytorch_qnnp_initialize();
65   if (status != pytorch_qnnp_status_success) {
66     state.SkipWithError("failed to initialize QNNPACK");
67   }
68 
69   pytorch_qnnp_operator_t convolutionObject = nullptr;
70   size_t num_zero_points_padded =
71     ((groups * groupOutputChannels + 7) / 8) * 8;
72   std::vector<uint8_t> kernel_zero_points(num_zero_points_padded, 127);
73   std::vector<float> requantization_scale(
74       num_zero_points_padded, 0.5 * 0.5 / 0.5);
75   status = pytorch_qnnp_create_convolution2d_nhwc_q8(
76       paddingHeight,
77       paddingWidth,
78       kernelHeight,
79       kernelWidth,
80       subsampling,
81       subsampling,
82       dilation,
83       dilation,
84       groups,
85       groupInputChannels,
86       groupOutputChannels,
87       127,
88       kernel_zero_points.data(),
89       kernel.data(),
90       bias.data(),
91       127,
92       0,
93       255,
94       0 /* flags */,
95       requantization_scale.data(),
96       per_channel,
97       &convolutionObject);
98   if (status != pytorch_qnnp_status_success) {
99     state.SkipWithError("failed to create Convolution operator");
100   }
101 
102   status = pytorch_qnnp_setup_convolution2d_nhwc_q8(
103       convolutionObject,
104       batchSize,
105       inputHeight,
106       inputWidth,
107       input.data(),
108       inputPixelStride,
109       output.data(),
110       outputPixelStride,
111       nullptr /* thread pool */);
112   if (status != pytorch_qnnp_status_success) {
113     state.SkipWithError("failed to setup Convolution operator");
114   }
115 
116   for (auto _ : state) {
117     pytorch_qnnp_run_operator(convolutionObject, nullptr /* thread pool */);
118   }
119 
120   status = pytorch_qnnp_delete_operator(convolutionObject);
121   if (status != pytorch_qnnp_status_success) {
122     state.SkipWithError("failed to delete Convolution operator");
123   }
124   convolutionObject = nullptr;
125 
126   state.SetItemsProcessed(
127       uint64_t(state.iterations()) * 2 * batchSize * outputHeight *
128       outputWidth * groups * groupInputChannels * groupOutputChannels *
129       kernelHeight * kernelWidth);
130 }
131 
132 /* ShuffleNet v1 with 1 group */
ShuffleNetV1G1(benchmark::internal::Benchmark * b)133 static void ShuffleNetV1G1(benchmark::internal::Benchmark* b) {
134   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
135 
136   /*********************** Conv 1 **********************/
137   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
138   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
139   /*************** Stage 2: stride-2 unit **************/
140   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
141   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 36});
142   b->Args({1, 56, 56, 3, 3, 2, 1, 36, 1, 1});
143   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 36, 120});
144   /*************** Stage 2: stride-1 units *************/
145   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
146   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 144, 36});
147   b->Args({1, 28, 28, 3, 3, 2, 1, 36, 1, 1});
148   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 36, 144});
149   /*************** Stage 3: stride-2 unit **************/
150   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
151   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 144, 72});
152   b->Args({1, 28, 28, 3, 3, 2, 1, 72, 1, 1});
153   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 72, 144});
154   /*************** Stage 3: stride-1 units *************/
155   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
156   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 288, 72});
157   b->Args({1, 14, 14, 3, 3, 2, 1, 72, 1, 1});
158   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 72, 288});
159   /*************** Stage 4: stride-2 unit **************/
160   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
161   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 288, 144});
162   b->Args({1, 14, 14, 3, 3, 2, 1, 144, 1, 1});
163   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 144, 288});
164   /*************** Stage 4: stride-1 units *************/
165   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
166   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 576, 144});
167   b->Args({1, 7, 7, 3, 3, 2, 1, 144, 1, 1});
168   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 144, 576});
169 }
170 
171 /* ShuffleNet v1 with 2 groups */
ShuffleNetV1G2(benchmark::internal::Benchmark * b)172 static void ShuffleNetV1G2(benchmark::internal::Benchmark* b) {
173   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
174 
175   /*********************** Conv 1 **********************/
176   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
177   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
178   /*************** Stage 2: stride-2 unit **************/
179   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
180   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 50});
181   b->Args({1, 56, 56, 3, 3, 2, 1, 50, 1, 1});
182   b->Args({1, 28, 28, 1, 1, 1, 1, 2, 25, 88});
183   /*************** Stage 2: stride-1 units *************/
184   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
185   b->Args({1, 28, 28, 1, 1, 1, 1, 2, 100, 25});
186   b->Args({1, 28, 28, 3, 3, 2, 1, 50, 1, 1});
187   b->Args({1, 28, 28, 1, 1, 1, 1, 2, 25, 100});
188   /*************** Stage 3: stride-2 unit **************/
189   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
190   b->Args({1, 28, 28, 1, 1, 1, 1, 2, 100, 50});
191   b->Args({1, 28, 28, 3, 3, 2, 1, 100, 1, 1});
192   b->Args({1, 14, 14, 1, 1, 1, 1, 2, 50, 100});
193   /*************** Stage 3: stride-1 units *************/
194   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
195   b->Args({1, 14, 14, 1, 1, 1, 1, 2, 200, 50});
196   b->Args({1, 14, 14, 3, 3, 2, 1, 100, 1, 1});
197   b->Args({1, 14, 14, 1, 1, 1, 1, 2, 50, 200});
198   /*************** Stage 4: stride-2 unit **************/
199   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
200   b->Args({1, 14, 14, 1, 1, 1, 1, 2, 200, 100});
201   b->Args({1, 14, 14, 3, 3, 2, 1, 200, 1, 1});
202   b->Args({1, 7, 7, 1, 1, 1, 1, 2, 100, 200});
203   /*************** Stage 4: stride-1 units *************/
204   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
205   b->Args({1, 7, 7, 1, 1, 1, 1, 2, 400, 100});
206   b->Args({1, 7, 7, 3, 3, 2, 1, 200, 1, 1});
207   b->Args({1, 7, 7, 1, 1, 1, 1, 2, 100, 400});
208 }
209 
210 /* ShuffleNet v1 with 3 groups */
ShuffleNetV1G3(benchmark::internal::Benchmark * b)211 static void ShuffleNetV1G3(benchmark::internal::Benchmark* b) {
212   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
213 
214   /*********************** Conv 1 **********************/
215   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
216   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
217   /*************** Stage 2: stride-2 unit **************/
218   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
219   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 60});
220   b->Args({1, 56, 56, 3, 3, 2, 1, 60, 1, 1});
221   b->Args({1, 28, 28, 1, 1, 1, 1, 3, 20, 72});
222   /*************** Stage 2: stride-1 units *************/
223   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
224   b->Args({1, 28, 28, 1, 1, 1, 1, 3, 80, 20});
225   b->Args({1, 28, 28, 3, 3, 2, 1, 60, 1, 1});
226   b->Args({1, 28, 28, 1, 1, 1, 1, 3, 20, 80});
227   /*************** Stage 3: stride-2 unit **************/
228   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
229   b->Args({1, 28, 28, 1, 1, 1, 1, 3, 80, 40});
230   b->Args({1, 28, 28, 3, 3, 2, 1, 120, 1, 1});
231   b->Args({1, 14, 14, 1, 1, 1, 1, 3, 40, 80});
232   /*************** Stage 3: stride-1 units *************/
233   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
234   b->Args({1, 14, 14, 1, 1, 1, 1, 3, 160, 40});
235   b->Args({1, 14, 14, 3, 3, 2, 1, 120, 1, 1});
236   b->Args({1, 14, 14, 1, 1, 1, 1, 3, 40, 160});
237   /*************** Stage 4: stride-2 unit **************/
238   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
239   b->Args({1, 14, 14, 1, 1, 1, 1, 3, 160, 80});
240   b->Args({1, 14, 14, 3, 3, 2, 1, 240, 1, 1});
241   b->Args({1, 7, 7, 1, 1, 1, 1, 3, 80, 160});
242   /*************** Stage 4: stride-1 units *************/
243   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
244   b->Args({1, 7, 7, 1, 1, 1, 1, 3, 320, 80});
245   b->Args({1, 7, 7, 3, 3, 2, 1, 240, 1, 1});
246   b->Args({1, 7, 7, 1, 1, 1, 1, 3, 80, 320});
247 }
248 
249 /* ShuffleNet v1 with 4 groups */
ShuffleNetV1G4(benchmark::internal::Benchmark * b)250 static void ShuffleNetV1G4(benchmark::internal::Benchmark* b) {
251   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
252 
253   /*********************** Conv 1 **********************/
254   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
255   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
256   /*************** Stage 2: stride-2 unit **************/
257   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
258   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 68});
259   b->Args({1, 56, 56, 3, 3, 2, 1, 68, 1, 1});
260   b->Args({1, 28, 28, 1, 1, 1, 1, 4, 17, 62});
261   /*************** Stage 2: stride-1 units *************/
262   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
263   b->Args({1, 28, 28, 1, 1, 1, 1, 4, 68, 17});
264   b->Args({1, 28, 28, 3, 3, 2, 1, 68, 1, 1});
265   b->Args({1, 28, 28, 1, 1, 1, 1, 4, 17, 68});
266   /*************** Stage 3: stride-2 unit **************/
267   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
268   b->Args({1, 28, 28, 1, 1, 1, 1, 4, 68, 34});
269   b->Args({1, 28, 28, 3, 3, 2, 1, 136, 1, 1});
270   b->Args({1, 14, 14, 1, 1, 1, 1, 4, 34, 68});
271   /*************** Stage 3: stride-1 units *************/
272   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
273   b->Args({1, 14, 14, 1, 1, 1, 1, 4, 136, 34});
274   b->Args({1, 14, 14, 3, 3, 2, 1, 136, 1, 1});
275   b->Args({1, 14, 14, 1, 1, 1, 1, 4, 34, 136});
276   /*************** Stage 4: stride-2 unit **************/
277   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
278   b->Args({1, 14, 14, 1, 1, 1, 1, 4, 136, 68});
279   b->Args({1, 14, 14, 3, 3, 2, 1, 272, 1, 1});
280   b->Args({1, 7, 7, 1, 1, 1, 1, 4, 68, 136});
281   /*************** Stage 4: stride-1 units *************/
282   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
283   b->Args({1, 7, 7, 1, 1, 1, 1, 4, 272, 68});
284   b->Args({1, 7, 7, 3, 3, 2, 1, 272, 1, 1});
285   b->Args({1, 7, 7, 1, 1, 1, 1, 4, 68, 272});
286 }
287 
288 /* ShuffleNet v1 with 8 groups */
ShuffleNetV1G8(benchmark::internal::Benchmark * b)289 static void ShuffleNetV1G8(benchmark::internal::Benchmark* b) {
290   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
291 
292   /*********************** Conv 1 **********************/
293   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
294   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
295   /*************** Stage 2: stride-2 unit **************/
296   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
297   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 96});
298   b->Args({1, 56, 56, 3, 3, 2, 1, 96, 1, 1});
299   b->Args({1, 28, 28, 1, 1, 1, 1, 8, 12, 45});
300   /*************** Stage 2: stride-1 units *************/
301   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
302   b->Args({1, 28, 28, 1, 1, 1, 1, 8, 48, 12});
303   b->Args({1, 28, 28, 3, 3, 2, 1, 96, 1, 1});
304   b->Args({1, 28, 28, 1, 1, 1, 1, 8, 12, 48});
305   /*************** Stage 3: stride-2 unit **************/
306   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
307   b->Args({1, 28, 28, 1, 1, 1, 1, 8, 48, 24});
308   b->Args({1, 28, 28, 3, 3, 2, 1, 192, 1, 1});
309   b->Args({1, 14, 14, 1, 1, 1, 1, 8, 24, 48});
310   /*************** Stage 3: stride-1 units *************/
311   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
312   b->Args({1, 14, 14, 1, 1, 1, 1, 8, 96, 24});
313   b->Args({1, 14, 14, 3, 3, 2, 1, 192, 1, 1});
314   b->Args({1, 14, 14, 1, 1, 1, 1, 8, 24, 96});
315   /*************** Stage 4: stride-2 unit **************/
316   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
317   b->Args({1, 14, 14, 1, 1, 1, 1, 8, 96, 48});
318   b->Args({1, 14, 14, 3, 3, 2, 1, 384, 1, 1});
319   b->Args({1, 7, 7, 1, 1, 1, 1, 8, 48, 96});
320   /*************** Stage 4: stride-1 units *************/
321   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
322   b->Args({1, 7, 7, 1, 1, 1, 1, 8, 192, 48});
323   b->Args({1, 7, 7, 3, 3, 2, 1, 384, 1, 1});
324   b->Args({1, 7, 7, 1, 1, 1, 1, 8, 48, 192});
325 }
326 
327 /* ShuffleNet v2 (0.5X scale) */
ShuffleNetV2X05(benchmark::internal::Benchmark * b)328 static void ShuffleNetV2X05(benchmark::internal::Benchmark* b) {
329   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
330 
331   /*********************** Conv 1 **********************/
332   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
333   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
334   /********************** Stage 2 **********************/
335   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
336   b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1});
337   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 24});
338   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 24});
339   b->Args({1, 28, 28, 3, 3, 1, 1, 24, 1, 1});
340   /********************** Stage 3 **********************/
341   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
342   b->Args({1, 28, 28, 3, 3, 2, 1, 48, 1, 1});
343   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 48, 48});
344   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 48, 48});
345   b->Args({1, 14, 14, 3, 3, 1, 1, 48, 1, 1});
346   /********************** Stage 4 **********************/
347   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
348   b->Args({1, 14, 14, 3, 3, 2, 1, 96, 1, 1});
349   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 96, 96});
350   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 96, 96});
351   b->Args({1, 7, 7, 3, 3, 1, 1, 96, 1, 1});
352   /*********************** Conv 5 **********************/
353   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
354   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 192, 1024});
355 }
356 
357 /* ShuffleNet v2 (1.0X scale) */
ShuffleNetV2X10(benchmark::internal::Benchmark * b)358 static void ShuffleNetV2X10(benchmark::internal::Benchmark* b) {
359   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
360 
361   /*********************** Conv 1 **********************/
362   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
363   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
364   /********************** Stage 2 **********************/
365   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
366   b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1});
367   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 58});
368   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 58});
369   b->Args({1, 56, 56, 3, 3, 2, 1, 58, 1, 1});
370   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 58, 58});
371   b->Args({1, 28, 28, 3, 3, 1, 1, 58, 1, 1});
372   /********************** Stage 3 **********************/
373   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
374   b->Args({1, 28, 28, 3, 3, 2, 1, 116, 1, 1});
375   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 116, 116});
376   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 116, 116});
377   b->Args({1, 14, 14, 3, 3, 1, 1, 116, 1, 1});
378   /********************** Stage 4 **********************/
379   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
380   b->Args({1, 14, 14, 3, 3, 2, 1, 232, 1, 1});
381   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 232, 232});
382   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 232, 232});
383   b->Args({1, 7, 7, 3, 3, 1, 1, 232, 1, 1});
384   /*********************** Conv 5 **********************/
385   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
386   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 464, 1024});
387 }
388 
389 /* ShuffleNet v2 (1.5X scale) */
ShuffleNetV2X15(benchmark::internal::Benchmark * b)390 static void ShuffleNetV2X15(benchmark::internal::Benchmark* b) {
391   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
392 
393   /*********************** Conv 1 **********************/
394   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
395   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
396   /********************** Stage 2 **********************/
397   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
398   b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1});
399   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 88});
400   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 88});
401   b->Args({1, 56, 56, 3, 3, 2, 1, 88, 1, 1});
402   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 88, 88});
403   b->Args({1, 28, 28, 3, 3, 1, 1, 88, 1, 1});
404   /********************** Stage 3 **********************/
405   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
406   b->Args({1, 28, 28, 3, 3, 2, 1, 176, 1, 1});
407   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 176, 176});
408   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 176, 176});
409   b->Args({1, 14, 14, 3, 3, 1, 1, 176, 1, 1});
410   /********************** Stage 4 **********************/
411   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
412   b->Args({1, 14, 14, 3, 3, 2, 1, 352, 1, 1});
413   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 352, 352});
414   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 352, 352});
415   b->Args({1, 7, 7, 3, 3, 1, 1, 352, 1, 1});
416   /*********************** Conv 5 **********************/
417   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
418   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 704, 1024});
419 }
420 
421 /* ShuffleNet v2 (2.0X scale) */
ShuffleNetV2X20(benchmark::internal::Benchmark * b)422 static void ShuffleNetV2X20(benchmark::internal::Benchmark* b) {
423   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
424 
425   /*********************** Conv 1 **********************/
426   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
427   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24});
428   /********************** Stage 2 **********************/
429   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
430   b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1});
431   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 122});
432   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 122});
433   b->Args({1, 56, 56, 3, 3, 2, 1, 122, 1, 1});
434   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 122, 122});
435   b->Args({1, 28, 28, 3, 3, 1, 1, 122, 1, 1});
436   /********************** Stage 3 **********************/
437   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
438   b->Args({1, 28, 28, 3, 3, 2, 1, 244, 1, 1});
439   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 244, 244});
440   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 244, 244});
441   b->Args({1, 14, 14, 3, 3, 1, 1, 244, 1, 1});
442   /********************** Stage 4 **********************/
443   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
444   b->Args({1, 14, 14, 3, 3, 2, 1, 488, 1, 1});
445   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 488, 488});
446   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 488, 488});
447   b->Args({1, 7, 7, 3, 3, 1, 1, 488, 1, 1});
448   /*********************** Conv 5 **********************/
449   /*       N   H    W   KH  KW  S  D   G   GCin  GCout */
450   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 976, 2048});
451 }
452 
MobileNetV1(benchmark::internal::Benchmark * b)453 static void MobileNetV1(benchmark::internal::Benchmark* b) {
454   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
455 
456   /*       N   H    W   KH  KW  S  D    G   GCin  GCout */
457   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 32});
458   b->Args({1, 112, 112, 3, 3, 1, 1, 32, 1, 1});
459   b->Args({1, 112, 112, 1, 1, 1, 1, 1, 32, 64});
460   b->Args({1, 112, 112, 3, 3, 2, 1, 64, 1, 1});
461   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 128});
462   b->Args({1, 56, 56, 3, 3, 1, 1, 128, 1, 1});
463   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 128, 128});
464   b->Args({1, 56, 56, 3, 3, 2, 1, 128, 1, 1});
465   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 128, 256});
466   b->Args({1, 28, 28, 3, 3, 1, 1, 256, 1, 1});
467   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 256, 256});
468   b->Args({1, 28, 28, 3, 3, 2, 1, 256, 1, 1});
469   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 256, 512});
470   b->Args({1, 14, 14, 3, 3, 1, 1, 512, 1, 1});
471   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 512, 512});
472   b->Args({1, 14, 14, 3, 3, 2, 1, 512, 1, 1});
473   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 512, 1024});
474   b->Args({1, 7, 7, 3, 3, 1, 1, 1024, 1, 1});
475   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 1024, 1024});
476 }
477 
MobileNetV2(benchmark::internal::Benchmark * b)478 static void MobileNetV2(benchmark::internal::Benchmark* b) {
479   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
480 
481   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
482   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 32});
483 
484   /******************** Bottleneck 1 *******************/
485   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
486   b->Args({1, 112, 112, 3, 3, 1, 1, 32, 1, 1});
487   b->Args({1, 112, 112, 1, 1, 1, 1, 1, 32, 16});
488 
489   /******************** Bottleneck 2 *******************/
490   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
491   b->Args({1, 112, 112, 1, 1, 1, 1, 1, 16, 96});
492   b->Args({1, 112, 112, 3, 3, 2, 1, 96, 1, 1});
493   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 96, 24});
494   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 144});
495   b->Args({1, 56, 56, 3, 3, 1, 1, 144, 1, 1});
496   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 144, 24});
497 
498   /******************** Bottleneck 3 *******************/
499   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
500   // b->Args({1,  56,  56,  1,  1, 1, 1,   1,   24,  144});
501   b->Args({1, 56, 56, 3, 3, 2, 1, 144, 1, 1});
502   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 144, 32});
503   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 32, 192});
504   b->Args({1, 28, 28, 3, 3, 1, 1, 192, 1, 1});
505   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 192, 32});
506   // b->Args({1,  28,  28,  1,  1, 1, 1,   1,   32,  192});
507   // b->Args({1,  28,  28,  3,  3, 1, 1, 192,    1,    1});
508   // b->Args({1,  28,  28,  1,  1, 1, 1,   1,  192,   32});
509 
510   /******************** Bottleneck 4 *******************/
511   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
512   // b->Args({1,  28,  28,  1,  1, 1, 1,   1,   32,  192});
513   b->Args({1, 28, 28, 3, 3, 2, 1, 192, 1, 1});
514   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 192, 64});
515   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 64, 384});
516   b->Args({1, 14, 14, 3, 3, 1, 1, 384, 1, 1});
517   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 384, 64});
518   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,   64,  384});
519   // b->Args({1,  14,  14,  3,  3, 1, 1, 384,    1,    1});
520   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,  384,   64});
521   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,   64,  384});
522   // b->Args({1,  14,  14,  3,  3, 1, 1, 384,    1,    1});
523   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,  384,   64});
524 
525   /******************** Bottleneck 5 *******************/
526   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
527   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,   64,  384});
528   // b->Args({1,  14,  14,  3,  3, 1, 1, 384,    1,    1});
529   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 384, 96});
530   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 96, 576});
531   b->Args({1, 14, 14, 3, 3, 1, 1, 576, 1, 1});
532   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 576, 96});
533   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,   96,  576});
534   // b->Args({1,  14,  14,  3,  3, 1, 1, 576,    1,    1});
535   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,  576,   96});
536 
537   /******************** Bottleneck 6 *******************/
538   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
539   // b->Args({1,  14,  14,  1,  1, 1, 1,   1,   96,  576});
540   b->Args({1, 14, 14, 3, 3, 2, 1, 576, 1, 1});
541   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 576, 160});
542   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 160, 960});
543   b->Args({1, 7, 7, 3, 3, 1, 1, 960, 1, 1});
544   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 960, 160});
545   // b->Args({1,   7,   7,  1,  1, 1, 1,   1,  160,  960});
546   // b->Args({1,   7,   7,  3,  3, 1, 1, 960,    1,    1});
547   // b->Args({1,   7,   7,  1,  1, 1, 1,   1,  960,  160});
548 
549   /******************** Bottleneck 7 *******************/
550   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
551   // b->Args({1,   7,   7,  1,  1, 1, 1,   1,  160,  960});
552   // b->Args({1,   7,   7,  3,  3, 1, 1, 960,    1,    1});
553   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 960, 320});
554 
555   /**************** Pre-pooling Conv2D *****************/
556   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
557   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 320, 1280});
558   /**************** Post-pooling Conv2D ****************/
559   /*       N   H    W   KH  KW  S  D    G  GCin  GCout */
560   b->Args({1, 1, 1, 1, 1, 1, 1, 1, 1280, 1000});
561 }
562 
563 /* SqueezeNet 1.0 */
SqueezeNetV10(benchmark::internal::Benchmark * b)564 static void SqueezeNetV10(benchmark::internal::Benchmark* b) {
565   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
566 
567   /********************** Conv 1 *********************/
568   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
569   b->Args({1, 224, 224, 7, 7, 2, 1, 1, 3, 96});
570   /********************** Fire 2 *********************/
571   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
572   b->Args({1, 55, 55, 1, 1, 1, 1, 1, 96, 16});
573   b->Args({1, 55, 55, 1, 1, 1, 1, 1, 16, 64});
574   b->Args({1, 55, 55, 3, 3, 1, 1, 1, 16, 64});
575   /********************** Fire 3 *********************/
576   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
577   b->Args({1, 56, 55, 1, 1, 1, 1, 1, 128, 16});
578   /*b->Args({1,  55,  55,  1,  1, 1, 1, 1,   16,   64});*/
579   /*b->Args({1,  55,  55,  3,  3, 1, 1, 1,   16,   64});*/
580   /********************** Fire 4 *********************/
581   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
582   b->Args({1, 55, 55, 1, 1, 1, 1, 1, 128, 32});
583   b->Args({1, 55, 55, 1, 1, 1, 1, 1, 32, 128});
584   b->Args({1, 55, 55, 3, 3, 1, 1, 1, 32, 128});
585   /********************** Fire 5 *********************/
586   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
587   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 256, 32});
588   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 32, 128});
589   b->Args({1, 27, 27, 3, 3, 1, 1, 1, 32, 128});
590   /********************** Fire 6 *********************/
591   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
592   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 256, 48});
593   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 48, 192});
594   b->Args({1, 27, 27, 3, 3, 1, 1, 1, 48, 192});
595   /********************** Fire 7 *********************/
596   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
597   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 384, 48});
598   /*b->Args({1,  27,  27,  1,  1, 1, 1, 1,   48,  192});*/
599   /*b->Args({1,  27,  27,  3,  3, 1, 1, 1,   48,  192});*/
600   /********************** Fire 8 *********************/
601   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
602   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 384, 64});
603   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 64, 256});
604   b->Args({1, 27, 27, 3, 3, 1, 1, 1, 64, 256});
605   /********************** Fire 9 *********************/
606   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
607   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 64});
608   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 64, 256});
609   b->Args({1, 13, 13, 3, 3, 1, 1, 1, 64, 256});
610   /********************* Conv 10 *********************/
611   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
612   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 1000});
613 }
614 
615 /* SqueezeNet 1.1 */
SqueezeNetV11(benchmark::internal::Benchmark * b)616 static void SqueezeNetV11(benchmark::internal::Benchmark* b) {
617   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
618 
619   /********************** Conv 1 *********************/
620   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
621   b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 64});
622   /********************** Fire 2 *********************/
623   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
624   b->Args({1, 55, 55, 1, 1, 1, 1, 1, 64, 16});
625   b->Args({1, 55, 55, 1, 1, 1, 1, 1, 16, 64});
626   b->Args({1, 55, 55, 3, 3, 1, 1, 1, 16, 64});
627   /********************** Fire 3 *********************/
628   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
629   b->Args({1, 55, 55, 1, 1, 1, 1, 1, 128, 16});
630   /*b->Args({1,  55,  55,  1,  1, 1, 1, 1,   16,   64});*/
631   /*b->Args({1,  55,  55,  3,  3, 1, 1, 1,   16,   64});*/
632   /********************** Fire 4 *********************/
633   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
634   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 128, 32});
635   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 32, 128});
636   b->Args({1, 27, 27, 3, 3, 1, 1, 1, 32, 128});
637   /********************** Fire 5 *********************/
638   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
639   b->Args({1, 27, 27, 1, 1, 1, 1, 1, 256, 32});
640   /*b->Args({1,  27,  27,  1,  1, 1, 1, 1,   32,  128});*/
641   /*b->Args({1,  27,  27,  3,  3, 1, 1, 1,   32,  128});*/
642   /********************** Fire 6 *********************/
643   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
644   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 256, 48});
645   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 48, 192});
646   b->Args({1, 13, 13, 3, 3, 1, 1, 1, 48, 192});
647   /********************** Fire 7 *********************/
648   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
649   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 384, 48});
650   /*b->Args({1,  13,  13,  1,  1, 1, 1, 1,   48,  192});*/
651   /*b->Args({1,  13,  13,  3,  3, 1, 1, 1,   48,  192});*/
652   /********************** Fire 8 *********************/
653   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
654   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 384, 64});
655   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 64, 256});
656   b->Args({1, 13, 13, 3, 3, 1, 1, 1, 64, 256});
657   /********************** Fire 9 *********************/
658   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
659   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 64});
660   /*b->Args({1,  13,  13,  1,  1, 1, 1, 1,   64,  256});*/
661   /*b->Args({1,  13,  13,  3,  3, 1, 1, 1,   64,  256});*/
662   /********************* Conv 10 *********************/
663   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
664   b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 1000});
665 }
666 
ResNet18(benchmark::internal::Benchmark * b)667 static void ResNet18(benchmark::internal::Benchmark* b) {
668   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
669 
670   /********************* Conv 1 *********************/
671   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
672   b->Args({1, 224, 224, 7, 7, 2, 1, 1, 3, 64});
673   /******************** Conv 2.X ********************/
674   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
675   b->Args({1, 56, 56, 3, 3, 1, 1, 1, 64, 64});
676   /******************** Conv 3.X ********************/
677   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
678   b->Args({1, 56, 56, 3, 3, 2, 1, 1, 64, 128});
679   b->Args({1, 28, 28, 3, 3, 1, 1, 1, 128, 128});
680   b->Args({1, 56, 56, 1, 1, 2, 1, 1, 64, 128});
681   /******************** Conv 4.X ********************/
682   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
683   b->Args({1, 28, 28, 3, 3, 2, 1, 1, 128, 256});
684   b->Args({1, 14, 14, 3, 3, 1, 1, 1, 256, 256});
685   b->Args({1, 28, 28, 1, 1, 2, 1, 1, 128, 256});
686   /******************** Conv 5.X ********************/
687   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
688   b->Args({1, 14, 14, 3, 3, 2, 1, 1, 256, 512});
689   b->Args({1, 7, 7, 3, 3, 1, 1, 1, 512, 512});
690   b->Args({1, 14, 14, 1, 1, 2, 1, 1, 256, 512});
691 }
692 
ResNet50(benchmark::internal::Benchmark * b)693 static void ResNet50(benchmark::internal::Benchmark* b) {
694   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
695 
696   /********************* Conv 1 *********************/
697   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
698   b->Args({1, 224, 224, 7, 7, 2, 1, 1, 3, 64});
699   /******************** Conv 2.1 ********************/
700   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
701   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 64});
702   b->Args({1, 56, 56, 3, 3, 1, 1, 1, 64, 64});
703   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 256});
704   /*b->Args({1,  56,  56,  1,  1, 1, 1, 1,   64,  256});*/
705   /******************** Conv 2.X ********************/
706   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
707   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 256, 64});
708   /*b->Args({1,  56,  56,  3,  3, 1, 1, 1,   64,   64});*/
709   /*b->Args({1,  56,  56,  1,  1, 1, 1, 1,   64,  256});*/
710   /******************** Conv 3.1 ********************/
711   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
712   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 256, 128});
713   b->Args({1, 56, 56, 3, 3, 2, 1, 1, 128, 128});
714   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 128, 512});
715   b->Args({1, 56, 56, 1, 1, 2, 1, 1, 256, 512});
716   /******************** Conv 3.X ********************/
717   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
718   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 512, 128});
719   b->Args({1, 28, 28, 3, 3, 1, 1, 1, 128, 128});
720   /*b->Args({1,  28,  28,  1,  1, 1, 1, 1,  128,  512});*/
721   /******************** Conv 4.1 ********************/
722   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
723   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 512, 256});
724   b->Args({1, 28, 28, 3, 3, 2, 1, 1, 256, 256});
725   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 256, 1024});
726   b->Args({1, 28, 28, 1, 1, 2, 1, 1, 512, 1024});
727   /******************** Conv 4.X ********************/
728   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
729   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 1024, 256});
730   b->Args({1, 14, 14, 3, 3, 1, 1, 1, 256, 256});
731   /*b->Args({1,  14,  14,  1,  1, 1, 1, 1,  256, 1024});*/
732   /******************** Conv 5.1 ********************/
733   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
734   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 1024, 512});
735   b->Args({1, 14, 14, 3, 3, 2, 1, 1, 512, 512});
736   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 512, 2048});
737   b->Args({1, 14, 14, 1, 1, 2, 1, 1, 1024, 2048});
738   /******************** Conv 5.X ********************/
739   /*       N   H    W   KH  KW  S  D  G GCin  GCout */
740   b->Args({1, 7, 7, 1, 1, 1, 1, 1, 2048, 512});
741   b->Args({1, 7, 7, 3, 3, 1, 1, 1, 512, 512});
742   /*b->Args({1,   7,   7,  1,  1, 1, 1, 1,  512, 2048});*/
743 }
744 
VGG(benchmark::internal::Benchmark * b)745 static void VGG(benchmark::internal::Benchmark* b) {
746   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
747 
748   /********************* Conv 1.1 ********************/
749   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
750   b->Args({1, 224, 224, 3, 3, 1, 1, 1, 3, 64});
751   /********************* Conv 1.2 ********************/
752   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
753   b->Args({1, 224, 224, 3, 3, 1, 1, 1, 64, 64});
754 
755   /********************* Conv 2.1 ********************/
756   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
757   b->Args({1, 112, 112, 3, 3, 1, 1, 1, 64, 128});
758   /********************* Conv 2.2 ********************/
759   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
760   b->Args({1, 112, 112, 3, 3, 1, 1, 1, 128, 128});
761 
762   /********************* Conv 3.1 ********************/
763   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
764   b->Args({1, 56, 56, 3, 3, 1, 1, 1, 128, 256});
765   /********************* Conv 3.2 ********************/
766   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
767   b->Args({1, 56, 56, 3, 3, 1, 1, 1, 256, 256});
768   /********************* Conv 3.3 ********************/
769   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
770   b->Args({1, 56, 56, 1, 1, 1, 1, 1, 256, 256});
771 
772   /********************* Conv 4.1 ********************/
773   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
774   b->Args({1, 28, 28, 3, 3, 1, 1, 1, 256, 512});
775   /********************* Conv 4.2 ********************/
776   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
777   b->Args({1, 28, 28, 3, 3, 1, 1, 1, 512, 512});
778   /********************* Conv 4.3 ********************/
779   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
780   b->Args({1, 28, 28, 1, 1, 1, 1, 1, 512, 512});
781 
782   /********************* Conv 5.X ********************/
783   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
784   b->Args({1, 14, 14, 3, 3, 1, 1, 1, 512, 512});
785   /********************* Conv 5.3 ********************/
786   /*       N   H    W   KH  KW  S  D  G  GCin  GCout */
787   b->Args({1, 14, 14, 1, 1, 1, 1, 1, 512, 512});
788 }
789 
DWConv3x3(benchmark::internal::Benchmark * b)790 static void DWConv3x3(benchmark::internal::Benchmark* b) {
791   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
792 
793   /********************** 96 x 96 *********************/
794   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
795   b->Args({1, 96, 96, 3, 3, 1, 1, 512, 1, 1});
796   b->Args({1, 96, 96, 3, 3, 1, 1, 256, 1, 1});
797   b->Args({1, 96, 96, 3, 3, 1, 1, 128, 1, 1});
798   b->Args({1, 96, 96, 3, 3, 1, 1, 64, 1, 1});
799   b->Args({1, 96, 96, 3, 3, 1, 1, 48, 1, 1});
800   b->Args({1, 96, 96, 3, 3, 1, 1, 32, 1, 1});
801   b->Args({1, 96, 96, 3, 3, 1, 1, 24, 1, 1});
802   b->Args({1, 96, 96, 3, 3, 1, 1, 16, 1, 1});
803   /********************** 32 x 32 *********************/
804   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
805   b->Args({1, 32, 32, 3, 3, 1, 1, 768, 1, 1});
806   b->Args({1, 32, 32, 3, 3, 1, 1, 512, 1, 1});
807   b->Args({1, 32, 32, 3, 3, 1, 1, 256, 1, 1});
808   b->Args({1, 32, 32, 3, 3, 1, 1, 128, 1, 1});
809   b->Args({1, 32, 32, 3, 3, 1, 1, 64, 1, 1});
810   b->Args({1, 32, 32, 3, 3, 1, 1, 48, 1, 1});
811   b->Args({1, 32, 32, 3, 3, 1, 1, 32, 1, 1});
812   b->Args({1, 32, 32, 3, 3, 1, 1, 24, 1, 1});
813   b->Args({1, 32, 32, 3, 3, 1, 1, 16, 1, 1});
814   /********************** 17 x 17 *********************/
815   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
816   b->Args({1, 17, 17, 3, 3, 1, 1, 1024, 1, 1});
817   b->Args({1, 17, 17, 3, 3, 1, 1, 768, 1, 1});
818   b->Args({1, 17, 17, 3, 3, 1, 1, 512, 1, 1});
819   b->Args({1, 17, 17, 3, 3, 1, 1, 384, 1, 1});
820   b->Args({1, 17, 17, 3, 3, 1, 1, 256, 1, 1});
821   b->Args({1, 17, 17, 3, 3, 1, 1, 128, 1, 1});
822   b->Args({1, 17, 17, 3, 3, 1, 1, 64, 1, 1});
823   b->Args({1, 17, 17, 3, 3, 1, 1, 32, 1, 1});
824   b->Args({1, 17, 17, 3, 3, 1, 1, 16, 1, 1});
825   /********************** 11 x 11 *********************/
826   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
827   b->Args({1, 11, 11, 3, 3, 1, 1, 1024, 1, 1});
828   b->Args({1, 11, 11, 3, 3, 1, 1, 768, 1, 1});
829   b->Args({1, 11, 11, 3, 3, 1, 1, 512, 1, 1});
830   b->Args({1, 11, 11, 3, 3, 1, 1, 384, 1, 1});
831   b->Args({1, 11, 11, 3, 3, 1, 1, 256, 1, 1});
832   b->Args({1, 11, 11, 3, 3, 1, 1, 192, 1, 1});
833   b->Args({1, 11, 11, 3, 3, 1, 1, 128, 1, 1});
834   b->Args({1, 11, 11, 3, 3, 1, 1, 64, 1, 1});
835   b->Args({1, 11, 11, 3, 3, 1, 1, 32, 1, 1});
836   b->Args({1, 11, 11, 3, 3, 1, 1, 16, 1, 1});
837   /*********************** 7 x 7 **********************/
838   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
839   b->Args({1, 7, 7, 3, 3, 1, 1, 1024, 1, 1});
840   b->Args({1, 7, 7, 3, 3, 1, 1, 768, 1, 1});
841   b->Args({1, 7, 7, 3, 3, 1, 1, 512, 1, 1});
842   b->Args({1, 7, 7, 3, 3, 1, 1, 384, 1, 1});
843   b->Args({1, 7, 7, 3, 3, 1, 1, 256, 1, 1});
844   b->Args({1, 7, 7, 3, 3, 1, 1, 128, 1, 1});
845   b->Args({1, 7, 7, 3, 3, 1, 1, 64, 1, 1});
846   b->Args({1, 7, 7, 3, 3, 1, 1, 32, 1, 1});
847   b->Args({1, 7, 7, 3, 3, 1, 1, 16, 1, 1});
848 }
849 
DWConv3x3d2(benchmark::internal::Benchmark * b)850 static void DWConv3x3d2(benchmark::internal::Benchmark* b) {
851   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
852 
853   /********************** 96 x 96 *********************/
854   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
855   b->Args({1, 96, 96, 3, 3, 1, 2, 512, 1, 1});
856   b->Args({1, 96, 96, 3, 3, 1, 2, 256, 1, 1});
857   b->Args({1, 96, 96, 3, 3, 1, 2, 128, 1, 1});
858   b->Args({1, 96, 96, 3, 3, 1, 2, 64, 1, 1});
859   b->Args({1, 96, 96, 3, 3, 1, 2, 48, 1, 1});
860   b->Args({1, 96, 96, 3, 3, 1, 2, 32, 1, 1});
861   b->Args({1, 96, 96, 3, 3, 1, 2, 24, 1, 1});
862   b->Args({1, 96, 96, 3, 3, 1, 2, 16, 1, 1});
863   /********************** 32 x 32 *********************/
864   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
865   b->Args({1, 32, 32, 3, 3, 1, 2, 768, 1, 1});
866   b->Args({1, 32, 32, 3, 3, 1, 2, 512, 1, 1});
867   b->Args({1, 32, 32, 3, 3, 1, 2, 256, 1, 1});
868   b->Args({1, 32, 32, 3, 3, 1, 2, 128, 1, 1});
869   b->Args({1, 32, 32, 3, 3, 1, 2, 64, 1, 1});
870   b->Args({1, 32, 32, 3, 3, 1, 2, 48, 1, 1});
871   b->Args({1, 32, 32, 3, 3, 1, 2, 32, 1, 1});
872   b->Args({1, 32, 32, 3, 3, 1, 2, 24, 1, 1});
873   b->Args({1, 32, 32, 3, 3, 1, 2, 16, 1, 1});
874   /********************** 17 x 17 *********************/
875   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
876   b->Args({1, 17, 17, 3, 3, 1, 2, 1024, 1, 1});
877   b->Args({1, 17, 17, 3, 3, 1, 2, 768, 1, 1});
878   b->Args({1, 17, 17, 3, 3, 1, 2, 512, 1, 1});
879   b->Args({1, 17, 17, 3, 3, 1, 2, 384, 1, 1});
880   b->Args({1, 17, 17, 3, 3, 1, 2, 256, 1, 1});
881   b->Args({1, 17, 17, 3, 3, 1, 2, 128, 1, 1});
882   b->Args({1, 17, 17, 3, 3, 1, 2, 64, 1, 1});
883   b->Args({1, 17, 17, 3, 3, 1, 2, 32, 1, 1});
884   b->Args({1, 17, 17, 3, 3, 1, 2, 16, 1, 1});
885   /********************** 11 x 11 *********************/
886   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
887   b->Args({1, 11, 11, 3, 3, 1, 2, 1024, 1, 1});
888   b->Args({1, 11, 11, 3, 3, 1, 2, 768, 1, 1});
889   b->Args({1, 11, 11, 3, 3, 1, 2, 512, 1, 1});
890   b->Args({1, 11, 11, 3, 3, 1, 2, 384, 1, 1});
891   b->Args({1, 11, 11, 3, 3, 1, 2, 256, 1, 1});
892   b->Args({1, 11, 11, 3, 3, 1, 2, 192, 1, 1});
893   b->Args({1, 11, 11, 3, 3, 1, 2, 128, 1, 1});
894   b->Args({1, 11, 11, 3, 3, 1, 2, 64, 1, 1});
895   b->Args({1, 11, 11, 3, 3, 1, 2, 32, 1, 1});
896   b->Args({1, 11, 11, 3, 3, 1, 2, 16, 1, 1});
897   /*********************** 7 x 7 **********************/
898   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
899   b->Args({1, 7, 7, 3, 3, 1, 2, 1024, 1, 1});
900   b->Args({1, 7, 7, 3, 3, 1, 2, 768, 1, 1});
901   b->Args({1, 7, 7, 3, 3, 1, 2, 512, 1, 1});
902   b->Args({1, 7, 7, 3, 3, 1, 2, 384, 1, 1});
903   b->Args({1, 7, 7, 3, 3, 1, 2, 256, 1, 1});
904   b->Args({1, 7, 7, 3, 3, 1, 2, 128, 1, 1});
905   b->Args({1, 7, 7, 3, 3, 1, 2, 64, 1, 1});
906   b->Args({1, 7, 7, 3, 3, 1, 2, 32, 1, 1});
907   b->Args({1, 7, 7, 3, 3, 1, 2, 16, 1, 1});
908 }
909 
DWConv5x5(benchmark::internal::Benchmark * b)910 static void DWConv5x5(benchmark::internal::Benchmark* b) {
911   b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"});
912 
913   /********************** 96 x 96 *********************/
914   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
915   b->Args({1, 96, 96, 5, 5, 1, 1, 512, 1, 1});
916   b->Args({1, 96, 96, 5, 5, 1, 1, 256, 1, 1});
917   b->Args({1, 96, 96, 5, 5, 1, 1, 128, 1, 1});
918   b->Args({1, 96, 96, 5, 5, 1, 1, 64, 1, 1});
919   b->Args({1, 96, 96, 5, 5, 1, 1, 48, 1, 1});
920   b->Args({1, 96, 96, 5, 5, 1, 1, 32, 1, 1});
921   b->Args({1, 96, 96, 5, 5, 1, 1, 24, 1, 1});
922   b->Args({1, 96, 96, 5, 5, 1, 1, 16, 1, 1});
923   /********************** 32 x 32 *********************/
924   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
925   b->Args({1, 32, 32, 5, 5, 1, 1, 768, 1, 1});
926   b->Args({1, 32, 32, 5, 5, 1, 1, 512, 1, 1});
927   b->Args({1, 32, 32, 5, 5, 1, 1, 256, 1, 1});
928   b->Args({1, 32, 32, 5, 5, 1, 1, 128, 1, 1});
929   b->Args({1, 32, 32, 5, 5, 1, 1, 64, 1, 1});
930   b->Args({1, 32, 32, 5, 5, 1, 1, 48, 1, 1});
931   b->Args({1, 32, 32, 5, 5, 1, 1, 32, 1, 1});
932   b->Args({1, 32, 32, 5, 5, 1, 1, 24, 1, 1});
933   b->Args({1, 32, 32, 5, 5, 1, 1, 16, 1, 1});
934   /********************** 17 x 17 *********************/
935   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
936   b->Args({1, 17, 17, 5, 5, 1, 1, 1024, 1, 1});
937   b->Args({1, 17, 17, 5, 5, 1, 1, 768, 1, 1});
938   b->Args({1, 17, 17, 5, 5, 1, 1, 512, 1, 1});
939   b->Args({1, 17, 17, 5, 5, 1, 1, 384, 1, 1});
940   b->Args({1, 17, 17, 5, 5, 1, 1, 256, 1, 1});
941   b->Args({1, 17, 17, 5, 5, 1, 1, 128, 1, 1});
942   b->Args({1, 17, 17, 5, 5, 1, 1, 64, 1, 1});
943   b->Args({1, 17, 17, 5, 5, 1, 1, 32, 1, 1});
944   b->Args({1, 17, 17, 5, 5, 1, 1, 16, 1, 1});
945   /********************** 11 x 11 *********************/
946   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
947   b->Args({1, 11, 11, 5, 5, 1, 1, 1024, 1, 1});
948   b->Args({1, 11, 11, 5, 5, 1, 1, 768, 1, 1});
949   b->Args({1, 11, 11, 5, 5, 1, 1, 512, 1, 1});
950   b->Args({1, 11, 11, 5, 5, 1, 1, 384, 1, 1});
951   b->Args({1, 11, 11, 5, 5, 1, 1, 256, 1, 1});
952   b->Args({1, 11, 11, 5, 5, 1, 1, 128, 1, 1});
953   b->Args({1, 11, 11, 5, 5, 1, 1, 64, 1, 1});
954   b->Args({1, 11, 11, 5, 5, 1, 1, 32, 1, 1});
955   b->Args({1, 11, 11, 5, 5, 1, 1, 16, 1, 1});
956   /*********************** 7 x 7 **********************/
957   /*       N   H   W  KH  KW  S  D    G   GCin  GCout */
958   b->Args({1, 7, 7, 5, 5, 1, 1, 1024, 1, 1});
959   b->Args({1, 7, 7, 5, 5, 1, 1, 768, 1, 1});
960   b->Args({1, 7, 7, 5, 5, 1, 1, 512, 1, 1});
961   b->Args({1, 7, 7, 5, 5, 1, 1, 384, 1, 1});
962   b->Args({1, 7, 7, 5, 5, 1, 1, 256, 1, 1});
963   b->Args({1, 7, 7, 5, 5, 1, 1, 128, 1, 1});
964   b->Args({1, 7, 7, 5, 5, 1, 1, 64, 1, 1});
965   b->Args({1, 7, 7, 5, 5, 1, 1, 32, 1, 1});
966   b->Args({1, 7, 7, 5, 5, 1, 1, 16, 1, 1});
967 }
968 
969 BENCHMARK_CAPTURE(convolution_q8, mobilenet_v1, "MobileNet v1")
970     ->Apply(MobileNetV1);
971 BENCHMARK_CAPTURE(convolution_q8, mobilenet_v2, "MobileNet v2")
972     ->Apply(MobileNetV2);
973 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g1, "ShuffleNet v1 (1 group)")
974     ->Apply(ShuffleNetV1G1);
975 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)")
976     ->Apply(ShuffleNetV1G2);
977 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)")
978     ->Apply(ShuffleNetV1G3);
979 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)")
980     ->Apply(ShuffleNetV1G4);
981 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")
982     ->Apply(ShuffleNetV1G8);
983 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x05, "ShuffleNet v2 0.5X")
984     ->Apply(ShuffleNetV2X05);
985 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x10, "ShuffleNet v2 1.0X")
986     ->Apply(ShuffleNetV2X10);
987 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x15, "ShuffleNet v2 1.5X")
988     ->Apply(ShuffleNetV2X15);
989 BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x20, "ShuffleNet v2 2.0X")
990     ->Apply(ShuffleNetV2X20);
991 BENCHMARK_CAPTURE(convolution_q8, squeezenet_v10, "SqueezeNet 1.0")
992     ->Apply(SqueezeNetV10);
993 BENCHMARK_CAPTURE(convolution_q8, squeezenet_v11, "SqueezeNet 1.1")
994     ->Apply(SqueezeNetV11);
995 BENCHMARK_CAPTURE(convolution_q8, resnet18, "ResNet-18")->Apply(ResNet18);
996 BENCHMARK_CAPTURE(convolution_q8, resnet50, "ResNet-50")->Apply(ResNet50);
997 BENCHMARK_CAPTURE(convolution_q8, vgg, "VGG")->Apply(VGG);
998 BENCHMARK_CAPTURE(convolution_q8, dwconv3x3, "3x3 DW Convolutions")
999     ->Apply(DWConv3x3);
1000 BENCHMARK_CAPTURE(
1001     convolution_q8,
1002     dwconv3x3d2,
1003     "3x3 DW Convolutions (dilation 2)")
1004     ->Apply(DWConv3x3d2);
1005 BENCHMARK_CAPTURE(convolution_q8, dwconv5x5, "5x5 DW Convolutions")
1006     ->Apply(DWConv5x5);
1007 BENCHMARK_CAPTURE(convolution_q8, dwconv3x3_per_channel, "3x3 DW Convolutions", true)
1008     ->Apply(DWConv3x3);
1009 BENCHMARK_CAPTURE(
1010     convolution_q8,
1011     dwconv3x3d2_per_channel,
1012     "3x3 DW Convolutions (dilation 2)", true)
1013     ->Apply(DWConv3x3d2);
1014 BENCHMARK_CAPTURE(convolution_q8, dwconv5x5_per_channel, "5x5 DW Convolutions", true)
1015     ->Apply(DWConv5x5);
1016 
1017 #ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN
1018 BENCHMARK_MAIN();
1019 #endif
1020