xref: /aosp_15_r20/external/XNNPACK/tools/generate-argmaxpool-test.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16from primes import next_prime
17import xngen
18import xnncommon
19
20
# Command-line interface of the generator. Both options are mandatory and each
# takes a single file path.
parser = argparse.ArgumentParser(
  description="ArgMaxPool microkernel test generator",
)
# Input: YAML specification listing the micro-kernels to generate tests for.
parser.add_argument(
  "-s", "--spec",
  metavar="FILE",
  required=True,
  help="Specification (YAML) file",
)
# Output: path of the C++ test source to produce.
parser.add_argument(
  "-o", "--output",
  metavar="FILE",
  required=True,
  help="Output (C++ source) file",
)
# `defines` has no command-line flag; parsed options always carry an empty list.
parser.set_defaults(defines=[])
28
29
def split_ukernel_name(name):
  """Decodes tile sizes and target from an ArgMaxPool micro-kernel name.

  The name must fully match
  ``xnn_<f16|f32>_argmaxpool_ukernel_[<P>p]<N>x__<target>_c<C>``.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    Tuple (primary_tile, incremental_tile, channel_tile, arch, isa). With a
    ``<P>p`` prefix, primary_tile is P and incremental_tile is N; without it,
    primary_tile is N and incremental_tile is 0. channel_tile is C. arch and
    isa are whatever xnncommon.parse_target_name returns for <target>.

  Raises:
    ValueError: if the name does not match the expected pattern.
  """
  parsed = re.fullmatch(
    r"xnn_(f16|f32)_argmaxpool_ukernel_((\d+)p)?(\d+)x__(.+)_c(\d+)", name)
  if parsed is None:
    raise ValueError("Unexpected microkernel name: " + name)

  pass_prefix, leading_tile, trailing_tile, target, channels = \
    parsed.group(2, 3, 4, 5, 6)

  if pass_prefix:
    # "<P>p<N>x": P rows in the primary pass, N rows per incremental pass.
    primary_tile, incremental_tile = int(leading_tile), int(trailing_tile)
  else:
    # "<N>x": single-pass kernel, no incremental tile.
    primary_tile, incremental_tile = int(trailing_tile), 0

  channel_tile = int(channels)

  arch, isa = xnncommon.parse_target_name(target_name=target)
  return primary_tile, incremental_tile, channel_tile, arch, isa
46
47
# xngen template expanded once per micro-kernel by generate_test_cases().
#
# Variables supplied by the generator: TEST_NAME, TEST_ARGS, DATATYPE,
# PRIMARY_TILE, INCREMENTAL_TILE, CHANNEL_TILE, ISA_CHECK and next_prime.
#
# Layout:
#   * INCREMENTAL_TILE == 0: "unipass" tests (pooling size <= PRIMARY_TILE).
#   * INCREMENTAL_TILE != 0: "twopass" tests (pooling size in
#     (PRIMARY_TILE, PRIMARY_TILE+INCREMENTAL_TILE]) and "multipass" tests
#     (pooling size above PRIMARY_TILE+INCREMENTAL_TILE).
#   * "few_output_pixels" tests are emitted for every kernel.
#
# Every test that loops over `pooling_elements` passes the loop variable to
# the tester, so each iteration exercises a different pooling size.
ARGMAXPOOL_TEST_TEMPLATE = """\
$if INCREMENTAL_TILE == 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .input_offset(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*8)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .input_offset(${next_prime(CHANNEL_TILE*2)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if INCREMENTAL_TILE != 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .input_offset(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*8)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .input_offset(${next_prime(CHANNEL_TILE*2)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if INCREMENTAL_TILE == 0:
  $MIN_POOLING, MAX_POOLING = 2, PRIMARY_TILE
$else:
  $MIN_POOLING, MAX_POOLING = PRIMARY_TILE + 1, PRIMARY_TILE + INCREMENTAL_TILE

TEST(${TEST_NAME}, few_output_pixels) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_input_offset) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*5+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_output_stride) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .output_stride(${next_prime(CHANNEL_TILE*5+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_step) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        for (size_t step = 2; step <= pooling_elements; step++) {
          ArgMaxPoolMicrokernelTester()
            .output_pixels(output_pixels)
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .step(step)
            .channels(channels)
            .output_stride(${next_prime(CHANNEL_TILE*5+1)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }
}
"""
661
662
def generate_test_cases(ukernel, primary_tile, incremental_tile, channel_tile,
                        isa):
  """Renders the unit tests for one ArgMaxPool micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function. It must contain at least
             three underscores; the part after the first one, upper-cased and
             with "UKERNEL_" removed, becomes the test suite name.
    primary_tile: Number of rows (pixels) processed per one iteration of the
                  primary outer loop of the micro-kernel.
    incremental_tile: Number of rows (pixels) processed per one iteration of
                      the incremental outer loop of the micro-kernel.
    channel_tile: Number of channels processed per one iteration of the inner
                  loops of the micro-kernel.
    isa: instruction set required to run the micro-kernel. When it is falsy,
         the Scalar tester variant is passed as an extra argument to Test().

  Returns:
    The result of expanding ARGMAXPOOL_TEST_TEMPLATE with xngen.preprocess.
  """
  _prefix, test_name = ukernel.split("_", 1)
  # Only the second component (the data type) is used; the unpacking also
  # rejects names with fewer than four "_"-separated parts.
  _, datatype, _, _ = ukernel.split("_", 3)

  if isa:
    test_args = [ukernel]
  else:
    test_args = [ukernel, "ArgMaxPoolMicrokernelTester::Variant::Scalar"]

  template_vars = {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "PRIMARY_TILE": primary_tile,
      "INCREMENTAL_TILE": incremental_tile,
      "CHANNEL_TILE": channel_tile,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  }
  return xngen.preprocess(ARGMAXPOOL_TEST_TEMPLATE, template_vars)
696
697
def main(args):
  """Generates the C++ test file described by a YAML specification.

  Args:
    args: Command-line arguments (without the program name); see `parser`.

  Raises:
    ValueError: if the top-level YAML value of the specification is not a
                list, or if a micro-kernel name is rejected by
                split_ukernel_name.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/argmaxpool.h>
#include "argmaxpool-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  pieces = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    primary_tile, incremental_tile, channel_tile, arch, isa = \
      split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, primary_tile, incremental_tile,
                                    channel_tile, isa)
    pieces.append(
      "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "".join(pieces)

  # Leave the output file untouched when its content is already up to date.
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as existing_file:
      if existing_file.read() == tests:
        return

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write(tests)
746
747
# Script entry point: pass the command-line arguments (without the program
# name) to main() when this file is executed directly.
if __name__ == "__main__":
  main(sys.argv[1:])
750