#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon


parser = argparse.ArgumentParser(
  description='ArgMaxPool microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Parses tile sizes and target from an ArgMaxPool micro-kernel name.

  Names look like ``xnn_<f16|f32>_argmaxpool_ukernel_[<P>p]<I>x__<target>_c<C>``.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    A tuple ``(primary_tile, incremental_tile, channel_tile, arch, isa)``.
    ``incremental_tile`` is 0 when the name has no ``<P>p`` prefix.

  Raises:
    ValueError: if ``name`` does not follow the naming scheme above.
  """
  match = re.fullmatch(r"xnn_(f16|f32)_argmaxpool_ukernel_((\d+)p)?(\d+)x__(.+)_c(\d+)", name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)

  if match.group(2):
    # "<P>p<I>x": primary tile P, incremental tile I.
    primary_tile = int(match.group(3))
    incremental_tile = int(match.group(4))
  else:
    # "<P>x": primary tile only.
    primary_tile = int(match.group(4))
    incremental_tile = 0

  channel_tile = int(match.group(6))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(5))
  return primary_tile, incremental_tile, channel_tile, arch, isa


# xngen template for the generated GoogleTest cases; its variables are supplied
# by generate_test_cases(). Kernels without an incremental tile get "unipass"
# tests, kernels with one get "twopass" and "multipass" tests, and every kernel
# gets the "few_output_pixels" tests over the pooling range it handles.
ARGMAXPOOL_TEST_TEMPLATE = """\
$if INCREMENTAL_TILE == 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE})
      .pooling_tile(${POOLING_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE})
      .pooling_tile(${POOLING_TILE})
      .channels(${CHANNEL_TILE})
      .input_offset(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${POOLING_TILE})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${POOLING_TILE})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*8)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE})
        .pooling_tile(${POOLING_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE})
        .pooling_tile(${POOLING_TILE})
        .channels(channels)
        .input_offset(${next_prime(CHANNEL_TILE*2)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if INCREMENTAL_TILE != 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
      .pooling_tile(${POOLING_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
      .pooling_tile(${POOLING_TILE})
      .channels(${CHANNEL_TILE})
      .input_offset(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${POOLING_TILE})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${POOLING_TILE})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*5)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
        .pooling_tile(${POOLING_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
        .pooling_tile(${POOLING_TILE})
        .channels(channels)
        .input_offset(${next_prime(CHANNEL_TILE*2)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${POOLING_TILE})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${POOLING_TILE})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if INCREMENTAL_TILE == 0:
  $MIN_POOLING, MAX_POOLING = 2, PRIMARY_TILE
$else:
  $MIN_POOLING, MAX_POOLING = PRIMARY_TILE + 1, PRIMARY_TILE + INCREMENTAL_TILE

TEST(${TEST_NAME}, few_output_pixels) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_input_offset) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*5+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_output_stride) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${POOLING_TILE})
          .channels(channels)
          .output_stride(${next_prime(CHANNEL_TILE*5+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_step) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        for (size_t step = 2; step <= pooling_elements; step++) {
          ArgMaxPoolMicrokernelTester()
            .output_pixels(output_pixels)
            .pooling_elements(pooling_elements)
            .pooling_tile(${POOLING_TILE})
            .step(step)
            .channels(channels)
            .output_stride(${next_prime(CHANNEL_TILE*5+1)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }
}
"""


def generate_test_cases(ukernel, primary_tile, incremental_tile, channel_tile,
                        isa):
  """Generates all test cases for an ARGMAXPOOL micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    primary_tile: Number of rows (pixels) processed per one iteration of the
      primary outer loop of the micro-kernel.
    incremental_tile: Number of rows (pixels) processed per one iteration of
      the incremental outer loop of the micro-kernel (0 for single-pass
      kernels).
    channel_tile: Number of channels processed per one iteration of the inner
      loops of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test cases.
  """
  _, test_name = ukernel.split("_", 1)
  _, datatype, _, _ = ukernel.split("_", 3)
  test_args = [ukernel]
  if not isa:
    # Kernels without an ISA requirement run the tester's Scalar variant.
    test_args.append("ArgMaxPoolMicrokernelTester::Variant::Scalar")
  # Argument list for ArgMaxPoolMicrokernelTester::pooling_tile(): the primary
  # tile, followed by the incremental tile only when it is non-zero.
  pooling_tile = ", ".join(
      str(tile) for tile in (primary_tile, incremental_tile) if tile)
  return xngen.preprocess(ARGMAXPOOL_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "PRIMARY_TILE": primary_tile,
      "INCREMENTAL_TILE": incremental_tile,
      "CHANNEL_TILE": channel_tile,
      "POOLING_TILE": pooling_tile,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })


def main(args):
  """Writes the C++ test source for every micro-kernel listed in the spec.

  The output file is rewritten only when its content would change.

  Args:
    args: command-line arguments without the program name (see ``parser``).

  Raises:
    ValueError: if the spec is not a YAML list, or if it names a micro-kernel
      that does not follow the expected naming scheme.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  tests = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/argmaxpool.h>
#include "argmaxpool-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    primary_tile, incremental_tile, channel_tile, arch, isa = \
      split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, primary_tile, incremental_tile,
                                    channel_tile, isa)
    tests += "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa)

  # Avoid touching the output (and triggering rebuilds) when nothing changed.
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != tests

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])