#!/usr/bin/env python3
# Copyright 2021 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Converts hand written assembly (.S files) to C++ files using the JIT.

Takes a single argument, an assembly file, and prints converted output to stdout.
"""

import argparse
import datetime
import re
import sys

# Building blocks for the instruction regexes. All fragments containing
# backslashes are raw strings so that Python never interprets them as (invalid)
# string escape sequences.
SPACES = r'\s*'
COMMA = r',' + SPACES
# Optional trailing `// comment`; contributes two groups (outer, inner).
COMMENTS = SPACES + r'((//\s+.+)|)$'
WB = r'!'

REG_NO_GROUP = r'r\d+|s\d+|d\d+|q\d+|sp|lr|pc|x\d+|(?:v\d+\.(?:\d+)?(?:d|s|h|b))'
REG = r'(' + REG_NO_GROUP + r')'
IMM_NO_GROUP = r'\d+'
IMM = r'(' + IMM_NO_GROUP + r')'
REG_LANE_NO_GROUP = r'(?:' + REG_NO_GROUP + r')\[' + IMM_NO_GROUP + r'\]'
REG_OR_IMM = r'(' + REG_LANE_NO_GROUP + r'|' + REG_NO_GROUP + r'|' + IMM_NO_GROUP + r')'

REGLIST_CONSEC = r'\{(\w+)-(\w+)\}' + SPACES
REGLIST_INDIV = r'\{([\w.]+(?:,\s+[\w.]+)*)\}' + SPACES
REGLIST_INDIV_REPLICATE = r'\{(\w+(?:\[\])(,\s*\w+(?:\[\]))*)\}' + SPACES
REGLIST_INDEX = r'\{(' + REG_LANE_NO_GROUP + r')\}' + SPACES

APSR = 'APSR_nzcv'
FPSCR = '(FPSCR)'

MEMOP = r'\[' + SPACES + REG + r'\]' + SPACES
MEMOP_MAYBE_WB = r'\[' + SPACES + REG + r'\]' + f'({WB})?'
MEMOP_OFFSET = r'\[' + REG + COMMA + r'(-?\d+)\]' + SPACES
MEMOP_OFFSET_MAYBE_WB = r'\[' + REG + COMMA + r'(-?\d+)\]' + f'({WB})?' + SPACES

B_IMM = r'(\d+)(f|b)'

INSTR = SPACES + r'([A-Z0-9.]+)' + SPACES

# e.g. #ifndef __APPLE__
IFDEF_RE = re.compile(r'\s*#(ifndef|endif|ifdef).*')
# e.g. # Push 96 bytes
COMMENT_RE = re.compile(SPACES + r'((//|#)\s*.+)')
# e.g. 0:
LABEL = re.compile(r'(\w+):')
# e.g. NOP
INSTR_RE = re.compile(INSTR + COMMENTS)
# e.g. VPUSH {d8-d15}
INSTR_REGLIST_CONSEC_RE = re.compile(INSTR + REGLIST_CONSEC + COMMENTS)
# e.g. PUSH {r4, r5}
INSTR_REGLIST_LIST_RE = re.compile(INSTR + REGLIST_INDIV + COMMENTS)
# e.g. BX lr
INSTR_OP_RE = re.compile(INSTR + REG + COMMENTS)
# e.g. BLO 2f
INSTR_B_IMM = re.compile(INSTR + B_IMM + COMMENTS)
# e.g. TBNZ x0, 4, 5f
INSTR_B_REG_IMM_IMM = re.compile(INSTR + REG + COMMA + IMM + COMMA + B_IMM + COMMENTS)
# e.g. .p2align 3
P2ALIGN_RE = re.compile(SPACES + r'\.p2align\s+(\d+)')
# e.g. CMP r0, 2
INSTR_REG_IMM_RE = re.compile(INSTR + REG + COMMA + IMM + COMMENTS)
# e.g. LDR r0, [r12]
INSTR_REG_MEMOP_RE = re.compile(INSTR + REG + COMMA + MEMOP + COMMENTS)
# e.g. LDR q0, [x4], 16
INSTR_REG_MEMOP_IMM_RE = re.compile(INSTR + REG + COMMA + MEMOP + COMMA + IMM + COMMENTS)
# e.g. LDR r0, [sp, 112], STR x20, [sp, -80]!
INSTR_REG_MEMOP_OFFSET_RE = re.compile(INSTR + REG + COMMA + MEMOP_OFFSET_MAYBE_WB +
                                       COMMENTS)
# e.g. LDRD r6, r7, [sp]
INSTR_REG_REG_MEMOP_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                    MEMOP + COMMENTS)
# e.g. LDRD r6, r7, [sp, 104], STP d8, d9, [sp, -64]!
INSTR_REG_REG_MEMOP_OFFSET_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                           MEMOP_OFFSET_MAYBE_WB + COMMENTS)
# e.g. LDP q20, q21, [x5], 32
INSTR_REG_REG_MEMOP_IMM_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                        MEMOP + COMMA + IMM + COMMENTS)
# e.g. PLD [r4, 64]
INSTR_MEMOP_OFFSET_RE = re.compile(INSTR + MEMOP_OFFSET + COMMENTS)
# e.g. movlo r12, r3, vdup.32 q0, d14[0]
INSTR_REG_REG_RE = re.compile(INSTR + REG + COMMA + REG_OR_IMM + COMMENTS)
# e.g. SUBS r5, r2, 16 or SUBS r5, r2, r10 or VMLFA.F32 q8, q4, d0[0]
INSTR_REG_REG_REG_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                  REG_OR_IMM + COMMENTS)
# e.g. VEXT.8 q0, q0, q0, 4
INSTR_REG_REG_REG_IMM_RE = re.compile(INSTR + REG + COMMA + REG + COMMA + REG +
                                      COMMA + IMM + COMMENTS)
# e.g. VST1.32 {d16}, [r11], r0
INSTR_REGLIST_INDIV_MEMOP_REG = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                           MEMOP + COMMA + REG + COMMENTS)
# e.g. VST1.32 {d16-d19}, [r11], r0
INSTR_REGLIST_CONSEC_MEMOP_REG = re.compile(INSTR + REGLIST_CONSEC + COMMA +
                                            MEMOP + COMMA + REG + COMMENTS)
# e.g. VLDM r9, {d16-d19}
INSTR_REG_REGLIST_CONSECT = re.compile(INSTR + REG + COMMA + REGLIST_CONSEC +
                                       COMMENTS)
# e.g. VLDM r9!, {d16-d19}
INSTR_REG_REGLIST_CONSECT_WB = re.compile(INSTR + REG + WB + COMMA +
                                          REGLIST_CONSEC + COMMENTS)
# e.g. VLDM r9!, {d16}
INSTR_REG_REGLIST_INDIV_WB = re.compile(INSTR + REG + WB + COMMA +
                                        REGLIST_INDIV + COMMENTS)
# e.g. VLD1.32 {d0}, [r3]{!}
INSTR_REGLIST_INDIV_MEMOP = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                       MEMOP_MAYBE_WB + COMMENTS)
# e.g. LD1 {v16.16b, v17.16b, v18.16b}, [x5], 48
INSTR_REGLIST_INDIV_MEMOP_IMM = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                           MEMOP + COMMA + IMM + COMMENTS)
# e.g. VST1.32 {d24-d25}, [r11]{!}
INSTR_REGLIST_CONSEC_MEMOP = re.compile(INSTR + REGLIST_CONSEC + COMMA +
                                        MEMOP_MAYBE_WB + COMMENTS)
# e.g. VLD1.32 {d0[]}, [r3]!
INSTR_REGLIST_REPLICATE_MEMOP = re.compile(INSTR + REGLIST_INDIV_REPLICATE +
                                           COMMA + MEMOP + r'(!)?' + COMMENTS)
# e.g. VST1.32 {d16[0]}, [r11]{!}
INSTR_REGLIST_INDEX_MEMOP = re.compile(INSTR + REGLIST_INDEX + COMMA +
                                       MEMOP_MAYBE_WB + COMMENTS)
# e.g. VMRS APSR_nzcv, FPSCR
INSTR_REG_FPSCR = re.compile(INSTR + f'({APSR}|{REG_NO_GROUP})' + COMMA +
                             FPSCR + COMMENTS)

# e.g. PRFM PLDL1KEEP, [x5]
INSTR_PLD_MEMOP = re.compile(INSTR + '(PLDL1KEEP)' + COMMA + MEMOP + COMMENTS)
# e.g. PRFM PLDL1KEEP, [x5, 64]
INSTR_PLD_MEMOP_OFFSET = re.compile(INSTR + '(PLDL1KEEP)' + COMMA + MEMOP_OFFSET + COMMENTS)

COND = r'([A-Z]+)'
# e.g. CSEL x9, x3, x9, LO
INSTR_REG_REG_REG_COND_RE = re.compile(INSTR + REG + COMMA + REG + COMMA + REG + COMMA + COND + COMMENTS)


def remove_brackets(s):
    """Returns `s` with every '[' and ']' removed, e.g. 'd0[]' -> 'd0'."""
    return s.replace('[', '').replace(']', '')


def fix_replicate_instruction(s):
    """Inserts 'r' before the first '_<digits>' suffix, e.g. 'vld1_32' -> 'vld1r_32'."""
    # `count` is passed by keyword: passing it positionally is deprecated.
    return re.sub(r'_(\d+)', r'r_\1', s, count=1)


def fix_instr_name(s):
    """Maps an assembly mnemonic to a method name.

    Lower-cases the mnemonic, turns up to two '.' into '_' and rewrites the first
    'and' to 'and_', e.g. 'VLD1.32' -> 'vld1_32', 'AND' -> 'and_'.
    """
    return s.lower().replace('.', '_', 2).replace('and', 'and_', 1)


def fix_comments(s):
    """Turns the first '#' of `s` into a C++ '//' comment marker."""
    return s.replace('#', '//', 1)


def maybe_wb(wb):
    """Returns '++' when a writeback marker was matched (truthy `wb`), else ''."""
    return '++' if wb else ''


def fix_fn_name(name):
    """Derives the generator function name from a microkernel name.

    Drops a leading 'xnn_', removes every 'minmax_' (the activation is not part
    of the generator name) and prepends 'xnn_generate_'.
    """
    if name.startswith('xnn_'):
        name = name[len('xnn_'):]
    # remove any type of activations from name
    if 'minmax' in name:
        name = name.replace('minmax_', '')
    return f'xnn_generate_{name}'


def remove_prfm_from_fn_name(name):
    """Returns `name` with 'prfm_' removed; `name` must contain '_prfm_'."""
    assert '_prfm_' in name
    return name.replace('prfm_', '')


def fix_regs(regs):
    """Rewrites vector registers with a datatype suffix into method calls."""
    # Vector registers with datatype need to be method calls.
    # e.g. v2.4s -> v2.v4s(), v2.s -> v2.s()
    def repl(m):
        if m.group(2):
            return f'{m[1]}v{m[2]}{m[3]}()'
        return f'{m[1]}{m[3]}()'
    return re.sub(r'(\w+\.)(\d+)?(\w+)', repl, regs)


IGNORE_LINES = [r'\s*\.\w+']

AARCH32 = 'aarch32'
AARCH64 = 'aarch64'
GEMM = 'GEMM'
IGEMM = 'IGEMM'


def _name(m):
    """Returns the converted mnemonic for a match whose group 1 is the instruction."""
    return fix_instr_name(m[1])


# Ordered (pattern, formatter) pairs. The patterns overlap, so the FIRST pattern
# that fully matches a line wins; do not reorder without checking for overlaps.
# Each formatter receives the match object and returns one C++ statement,
# followed by the original trailing `//` comment (possibly empty).
_INSTRUCTION_RULES = (
    (INSTR_RE,
     lambda m: f'{_name(m)}(); {m[2]}'),
    (INSTR_OP_RE,
     lambda m: f'{_name(m)}({m[2]}); {m[3]}'),
    (INSTR_REGLIST_CONSEC_MEMOP_REG,
     lambda m: f'{_name(m)}({{{m[2]}-{m[3]}}}, mem[{m[4]}], {m[5]}); {m[6]}'),
    (INSTR_REGLIST_INDIV_MEMOP_REG,
     lambda m: f'{_name(m)}({{{fix_regs(m[2])}}}, mem[{m[3]}], {m[4]}); {m[5]}'),
    (INSTR_REGLIST_CONSEC_RE,
     lambda m: f'{_name(m)}({{{m[2]}-{m[3]}}}); {m[4]}'),
    (INSTR_REGLIST_LIST_RE,
     lambda m: f'{_name(m)}({{{m[2]}}}); {m[3]}'),
    (INSTR_MEMOP_OFFSET_RE,
     lambda m: f'{_name(m)}(mem[{m[2]}, {m[3]}]); {m[4]}'),
    (INSTR_REG_MEMOP_RE,
     lambda m: f'{_name(m)}({m[2]}, mem[{m[3]}]); {m[4]}'),
    (INSTR_REG_MEMOP_IMM_RE,
     lambda m: f'{_name(m)}({m[2]}, mem[{m[3]}], {m[4]}); {m[5]}'),
    (INSTR_REG_MEMOP_OFFSET_RE,
     lambda m: f'{_name(m)}({m[2]}, mem[{m[3]}, {m[4]}]{maybe_wb(m[5])}); {m[6]}'),
    (INSTR_REG_REG_MEMOP_RE,
     lambda m: f'{_name(m)}({m[2]}, {m[3]}, mem[{m[4]}]); {m[5]}'),
    (INSTR_REG_REG_MEMOP_OFFSET_RE,
     lambda m: f'{_name(m)}({m[2]}, {m[3]}, mem[{m[4]}, {m[5]}]{maybe_wb(m[6])}); {m[7]}'),
    (INSTR_REG_REG_MEMOP_IMM_RE,
     lambda m: f'{_name(m)}({m[2]}, {m[3]}, mem[{m[4]}], {m[5]}); {m[6]}'),
    (INSTR_REG_IMM_RE,
     lambda m: f'{_name(m)}({fix_regs(m[2])}, {m[3]}); {m[4]}'),
    (INSTR_REG_REG_REG_RE,
     lambda m: f'{_name(m)}({fix_regs(m[2])}, {fix_regs(m[3])}, {fix_regs(m[4])}); {m[5]}'),
    (INSTR_REG_REG_REG_IMM_RE,
     lambda m: f'{_name(m)}({m[2]}, {m[3]}, {m[4]}, {m[5]}); {m[6]}'),
    (INSTR_REG_REG_RE,
     lambda m: f'{_name(m)}({fix_regs(m[2])}, {fix_regs(m[3])}); {m[4]}'),
    (INSTR_REG_REGLIST_CONSECT,
     lambda m: f'{_name(m)}(mem[{m[2]}], {{{m[3]}-{m[4]}}}); {m[5]}'),
    (INSTR_REG_REGLIST_CONSECT_WB,
     lambda m: f'{_name(m)}(mem[{m[2]}]++, {{{m[3]}-{m[4]}}}); {m[5]}'),
    (INSTR_REG_REGLIST_INDIV_WB,
     lambda m: f'{_name(m)}(mem[{m[2]}]++, {{{m[3]}}}); {m[4]}'),
    # B_IMM has two groups (label number, direction); only the number is used.
    (INSTR_B_IMM,
     lambda m: f'{_name(m)}(l{m[2]}); {m[4]}'),
    (INSTR_B_REG_IMM_IMM,
     lambda m: f'{_name(m)}({m[2]}, {m[3]}, l{m[4]}); {m[6]}'),
    (INSTR_REGLIST_INDIV_MEMOP,
     lambda m: f'{_name(m)}({{{fix_regs(m[2])}}}, mem[{m[3]}]{maybe_wb(m[4])}); {m[5]}'),
    (INSTR_REGLIST_INDIV_MEMOP_IMM,
     lambda m: f'{_name(m)}({{{fix_regs(m[2])}}}, mem[{m[3]}], {m[4]}); {m[5]}'),
    (INSTR_REGLIST_CONSEC_MEMOP,
     lambda m: f'{_name(m)}({{{m[2]}-{m[3]}}}, mem[{m[4]}]{maybe_wb(m[5])}); {m[6]}'),
    # REGLIST_INDIV_REPLICATE has two groups, so the memory register is group 4.
    (INSTR_REGLIST_REPLICATE_MEMOP,
     lambda m: (f'{fix_replicate_instruction(_name(m))}'
                f'({{{remove_brackets(m[2])}}}, mem[{m[4]}]{maybe_wb(m[5])}); {m[6]}')),
    (INSTR_REGLIST_INDEX_MEMOP,
     lambda m: f'{_name(m)}({{{m[2]}}}, mem[{m[3]}]{maybe_wb(m[4])}); {m[5]}'),
    # `.p2align n` aligns to 2**n bytes.
    (P2ALIGN_RE,
     lambda m: f'align({1 << int(m[1])});'),
    (INSTR_REG_FPSCR,
     lambda m: f'{_name(m)}({m[2]}, {m[3]}); {m[4]}'),
    (INSTR_PLD_MEMOP,
     lambda m: f'{_name(m)}(k{m[2]}, mem[{m[3]}]); {m[4]}'),
    (INSTR_PLD_MEMOP_OFFSET,
     lambda m: f'{_name(m)}(k{m[2]}, mem[{m[3]}, {m[4]}]); {m[5]}'),
    (INSTR_REG_REG_REG_COND_RE,
     lambda m: f'{_name(m)}({m[2]}, {m[3]}, {m[4]}, k{m[5]}); {m[6]}'),
)


def _convert_instruction(line):
    """Returns the C++ statement for an assembly `line`, or None if nothing matches."""
    for pattern, emit in _INSTRUCTION_RULES:
        m = pattern.fullmatch(line)
        if m:
            return emit(m)
    return None


def _generate_signature(kernel_type, minmax, prfm):
    """Returns the parameter list shared by the `generate` declaration and definition."""
    params = 'float min, float max' if minmax else 'void* params'
    prefetch = 'bool prefetch, ' if prfm else ''
    ks = '' if kernel_type == GEMM else 'size_t ks, '
    return f'{prefetch}size_t max_mr, size_t nc_mod_nr, size_t kc, {ks}{params}'


def _header_lines(arch, kernel_type, signature):
    """Returns the C++ includes and Generator class that replace the assembly include."""
    kernel_header = 'gemm' if kernel_type == GEMM else 'igemm'
    return [
        '#include <cassert>',
        '#include <cstddef>',
        '#include <limits>',
        '',
        f'#include <xnnpack/{arch}-assembler.h>',
        '#include <xnnpack/allocator.h>',
        f'#include <xnnpack/{kernel_header}.h>',
        '',
        'namespace xnnpack {',
        f'namespace {arch} {{',
        'namespace {',
        'class Generator : public Assembler {',
        ' using Assembler::Assembler;',
        ' public:',
        f' void generate({signature});',
        '};',
    ]


def main(input_file):
    """Converts the assembly microkernel in `input_file` and prints C++ to stdout.

    The architecture ('aarch32'/'aarch64'), kernel type ('igemm'), activation
    ('minmax'), prefetch variant ('prfm') and tile size ('<mr>x<nr>') are all
    derived from substrings of `input_file`. Exits with status 1, reporting on
    stderr, if the architecture is unknown or a line cannot be converted.
    """
    if 'aarch32' in input_file:
        arch = AARCH32
    elif 'aarch64' in input_file:
        arch = AARCH64
    else:
        # stdout carries the generated C++, so diagnostics belong on stderr.
        print('ERROR: unknown architecture', file=sys.stderr)
        sys.exit(1)

    kernel_type = IGEMM if 'igemm' in input_file else GEMM
    minmax = 'minmax' in input_file
    prfm = 'prfm' in input_file
    ctype = 'float'
    signature = _generate_signature(kernel_type, minmax, prfm)

    # Whether we are in the microkernel function.
    in_function = False
    # Whether we are in the auto-generated comment.
    in_autogen = False
    # Instructions that make up the microkernel.
    instructions = []
    # Lines of code or comments before the actual function body.
    prologue = []
    # All labels need to be declared first, collect them and output them after
    # function signature.
    labels = []
    # Name of the microkernel function.
    fn_name = ''

    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.rstrip()

            # Handle all lines before the microkernel instructions begin.
            if not in_function:
                if 'Auto-generated file' in line:
                    in_autogen = True
                elif 'BEGIN_FUNCTION' in line:
                    in_function = True
                    fn_name = line.split()[1]
                    # NOTE(review): the fixed [20:] slice presumably strips a
                    # repository path prefix from the file name -- confirm.
                    prologue.append(f'// Converted from: {input_file[20:]}')
                    prologue.append(f'void Generator::generate({signature}) {{')
                elif 'Copyright ' in line:
                    in_autogen = False
                    # replace year
                    this_year = str(datetime.date.today().year)
                    prologue.append(re.sub(r'\d{4}', this_year, line, count=1).rstrip())
                elif '#include <xnnpack/assembly.h>' in line:
                    prologue.extend(_header_lines(arch, kernel_type, signature))
                elif any(re.fullmatch(p, line) for p in IGNORE_LINES):
                    pass
                elif in_autogen:
                    pass
                else:
                    prologue.append(fix_comments(line.rstrip()))
                continue

            # We are now in the microkernel function body.
            # Don't keep the ifdefs.
            if IFDEF_RE.fullmatch(line):
                continue
            # But keep other comments.
            m = COMMENT_RE.fullmatch(line)
            if m:
                instructions.append(m[1])
                continue

            m = LABEL.fullmatch(line)
            if m:
                labels.append(m[1])
                instructions.append(f'bind(l{m[1]});')
                continue

            converted = _convert_instruction(line)
            if converted is not None:
                instructions.append(converted)
                continue

            stripped = line.strip()
            # Keep empty lines for formatting
            if stripped == '':
                instructions.append('')
                continue

            # Assembly directives that we don't care about.
            if stripped.startswith('.'):
                continue

            if line.startswith('END_FUNCTION'):
                break

            # All other lines are error.
            print(f'ERROR: {line}', file=sys.stderr)
            sys.exit(1)

    # Actually emit the JIT codegen (to stdout).
    for p in prologue:
        print(p)

    m = re.search(r'(\d+)x(\d+)', input_file)
    mr, nr = (m[1], m[2]) if m else (0, 0)
    print(f' assert(max_mr <= {mr});')
    print(f' assert(nc_mod_nr < {nr});')
    print(' assert(kc != 0);')
    print(f' assert(kc % sizeof({ctype}) == 0);')
    print()
    # An empty declaration (`Label ;`) is not valid C++, so only declare labels
    # when the kernel actually has some.
    if labels:
        labels_str = ', '.join(f'l{label}' for label in labels)
        print(f' Label {labels_str};')
        print()
    if minmax:
        print(' // const bool clamp_min = min != -std::numeric_limits<float>::infinity();')
        print(' // const bool clamp_max = max != +std::numeric_limits<float>::infinity();')

    indent = ' '
    for instruction in instructions:
        stripped = instruction.strip()
        if stripped.startswith('#'):
            print(indent + fix_comments(instruction))
        elif stripped.startswith('//'):
            print(indent + instruction)
        elif stripped == '':
            print()
        else:
            print(indent + instruction.rstrip())
    print(indent + 'align(16, AlignInstruction::kHlt);')

    print('}')
    print('} // namespace')
    print(f'}} // {arch}')
    print('} // xnnpack')
    print('')
    if prfm:
        # A prefetching kernel yields two entry points: one named without
        # 'prfm_' that disables prefetch, and one that enables it.
        print_generator_definition(kernel_type, remove_prfm_from_fn_name(fn_name), arch, minmax, prefetch='false, ')
        print()
        print_generator_definition(kernel_type, fn_name, arch, minmax, prefetch='true, ')
    else:
        print_generator_definition(kernel_type, fn_name, arch, minmax)


def print_generator_definition(kernel_type, fn_name, arch, minmax, prefetch=''):
    """Prints the `xnn_generate_*` C++ entry point that drives the Generator.

    Args:
      kernel_type: GEMM or IGEMM; IGEMM adds a `ks` parameter and argument.
      fn_name: microkernel name, converted with fix_fn_name.
      arch: architecture namespace name used in `using namespace xnnpack::...`.
      minmax: if true, min/max are read from `params` and passed to generate().
      prefetch: text inserted verbatim as the leading generate() argument(s),
        e.g. 'true, '; empty when the kernel has no prefetch parameter.
    """
    ks_param = '' if kernel_type == GEMM else 'size_t ks, '
    ks_arg = '' if kernel_type == GEMM else 'ks, '
    print(f'xnn_status {fix_fn_name(fn_name)}(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, {ks_param}const void* params) {{')
    print(f' using namespace xnnpack::{arch};')
    print(' Generator g(code);')
    if minmax:
        print(' assert(params != nullptr);')
        print(' const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);')
        activation_args = 'gemm_params->f32_minmax.min, gemm_params->f32_minmax.max'
    else:
        activation_args = 'nullptr'
    print(f' g.generate({prefetch}max_mr, nc_mod_nr, kc, {ks_arg}{activation_args});')
    print(' g.finalize();')
    print(' if (g.error() != xnnpack::Error::kNoError) {')
    print(' return xnn_status_invalid_state;')
    print(' }')
    print(' return xnn_status_success;')
    print('}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert assembly to JIT C++, writes to stdout.')
    parser.add_argument('input_file', help='Input assembly filename')
    args = parser.parse_args()
    main(args.input_file)