xref: /aosp_15_r20/external/XNNPACK/scripts/convert-assembly-to-jit.py (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1#!/usr/bin/env python3
2# Copyright 2021 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6"""Converts hand written assembly (.S files) to C++ files using the JIT.
7
8Takes a single argument, an assembly file, and prints converted output to stdout.
9"""
10
11import argparse
12import datetime
13import re
14import sys
15
# Building blocks for the instruction regexes below. Every fragment is a raw
# string so that regex escapes such as \s, \d, \] and \} are never interpreted
# as (invalid) Python string escapes.
SPACES = r'\s*'
COMMA = r',' + SPACES
# Optional trailing `// comment`; always captured (as '' when absent).
COMMENTS = SPACES + r'((//\s+.+)|)$'
# Writeback marker.
WB = r'!'

REG_NO_GROUP = r'r\d+|s\d+|d\d+|q\d+|sp|lr|pc|x\d+|(?:v\d+\.(?:\d+)?(?:d|s|h|b))'
REG = r'(' + REG_NO_GROUP + r')'
IMM_NO_GROUP = r'\d+'
IMM = r'(' + IMM_NO_GROUP + r')'
# A register followed by a lane index, e.g. d0[0].
REG_LANE_NO_GROUP = r'(?:' + REG_NO_GROUP + r')\[' + IMM_NO_GROUP + r'\]'
REG_OR_IMM = r'(' + REG_LANE_NO_GROUP + r'|' + REG_NO_GROUP + r'|' + IMM_NO_GROUP + r')'

# e.g. {d8-d15}
REGLIST_CONSEC = r'\{(\w+)-(\w+)\}' + SPACES
# e.g. {r4, r5}
REGLIST_INDIV = r'\{([\w.]+(?:,\s+[\w.]+)*)\}' + SPACES
# e.g. {d0[], d1[]}
REGLIST_INDIV_REPLICATE = r'\{(\w+(?:\[\])(,\s*\w+(?:\[\]))*)\}' + SPACES
# e.g. {d16[0]}
REGLIST_INDEX = r'\{(' + REG_LANE_NO_GROUP + r')\}' + SPACES

APSR = 'APSR_nzcv'
FPSCR = '(FPSCR)'

# e.g. [r12]
MEMOP = r'\[' + SPACES + REG + r'\]' + SPACES
# e.g. [r3] or [r3]!
MEMOP_MAYBE_WB = r'\[' + SPACES + REG + r'\]' + f'({WB})?'
# e.g. [sp, 112]
MEMOP_OFFSET = r'\[' + REG + COMMA + r'(-?\d+)\]' + SPACES
# e.g. [sp, -80] or [sp, -80]!
MEMOP_OFFSET_MAYBE_WB = r'\[' + REG + COMMA + r'(-?\d+)\]' + f'({WB})?' + SPACES

# Branch target: numeric label plus direction (forward/backward), e.g. 2f.
B_IMM = r'(\d+)(f|b)'

# Instruction mnemonic (upper case letters, digits and dots), e.g. VLD1.32.
INSTR = SPACES + r'([A-Z0-9.]+)' + SPACES

# e.g. #ifndef __APPLE__
IFDEF_RE = re.compile(r'\s*#(ifndef|endif|ifdef).*')
# e.g. # Push 96 bytes
COMMENT_RE = re.compile(SPACES + r'((//|#)\s*.+)')
# e.g. 0:
LABEL = re.compile(r'(\w+):')
# e.g. NOP
INSTR_RE = re.compile(INSTR + COMMENTS)
# e.g. VPUSH {d8-d15}
INSTR_REGLIST_CONSEC_RE = re.compile(INSTR + REGLIST_CONSEC + COMMENTS)
# e.g. PUSH {r4, r5}
INSTR_REGLIST_LIST_RE = re.compile(INSTR + REGLIST_INDIV + COMMENTS)
# e.g. BX lr
INSTR_OP_RE = re.compile(INSTR + REG + COMMENTS)
# e.g. BLO 2f
INSTR_B_IMM = re.compile(INSTR + B_IMM + COMMENTS)
# e.g. TBNZ x0, 4, 5f
INSTR_B_REG_IMM_IMM = re.compile(INSTR + REG + COMMA + IMM + COMMA + B_IMM + COMMENTS)
# e.g. .p2align 3
P2ALIGN_RE = re.compile(SPACES + r'\.p2align\s+(\d+)')
# e.g. CMP r0, 2
INSTR_REG_IMM_RE = re.compile(INSTR + REG + COMMA + IMM + COMMENTS)
# e.g. LDR r0, [r12]
INSTR_REG_MEMOP_RE = re.compile(INSTR + REG + COMMA + MEMOP + COMMENTS)
# e.g. LDR q0, [x4], 16
INSTR_REG_MEMOP_IMM_RE = re.compile(INSTR + REG + COMMA + MEMOP + COMMA + IMM + COMMENTS)
# e.g. LDR r0, [sp, 112], STR x20, [sp, -80]!
INSTR_REG_MEMOP_OFFSET_RE = re.compile(INSTR + REG + COMMA + MEMOP_OFFSET_MAYBE_WB +
                                       COMMENTS)
# e.g. LDRD r6, r7, [sp]
INSTR_REG_REG_MEMOP_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                    MEMOP + COMMENTS)
# e.g. LDRD r6, r7, [sp, 104], STP d8, d9, [sp, -64]!
INSTR_REG_REG_MEMOP_OFFSET_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                           MEMOP_OFFSET_MAYBE_WB + COMMENTS)
# e.g. LDP q20, q21, [x5], 32
INSTR_REG_REG_MEMOP_IMM_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                        MEMOP + COMMA + IMM + COMMENTS)
# e.g. PLD [r4, 64]
INSTR_MEMOP_OFFSET_RE = re.compile(INSTR + MEMOP_OFFSET + COMMENTS)
# e.g. MOVLO r12, r3 or VDUP.32 q0, d14[0]
INSTR_REG_REG_RE = re.compile(INSTR + REG + COMMA + REG_OR_IMM + COMMENTS)
# e.g. SUBS r5, r2, 16 or SUBS r5, r2, r10 or VMLA.F32 q8, q4, d0[0]
INSTR_REG_REG_REG_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                  REG_OR_IMM + COMMENTS)
# e.g. VEXT.8  q0, q0, q0, 4
INSTR_REG_REG_REG_IMM_RE = re.compile(INSTR + REG + COMMA + REG + COMMA + REG +
                                      COMMA + IMM + COMMENTS)
# e.g. VST1.32 {d16}, [r11], r0
INSTR_REGLIST_INDIV_MEMOP_REG = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                           MEMOP + COMMA + REG + COMMENTS)
# e.g. VST1.32 {d16-d19}, [r11], r0
INSTR_REGLIST_CONSEC_MEMOP_REG = re.compile(INSTR + REGLIST_CONSEC + COMMA +
                                            MEMOP + COMMA + REG + COMMENTS)
# e.g. VLDM r9, {d16-d19}
INSTR_REG_REGLIST_CONSECT = re.compile(INSTR + REG + COMMA + REGLIST_CONSEC +
                                       COMMENTS)
# e.g. VLDM r9!, {d16-d19}
INSTR_REG_REGLIST_CONSECT_WB = re.compile(INSTR + REG + WB + COMMA +
                                          REGLIST_CONSEC + COMMENTS)
# e.g. VLDM r9!, {d16}
INSTR_REG_REGLIST_INDIV_WB = re.compile(INSTR + REG + WB + COMMA +
                                        REGLIST_INDIV + COMMENTS)
# e.g. VLD1.32 {d0}, [r3]{!}
INSTR_REGLIST_INDIV_MEMOP = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                       MEMOP_MAYBE_WB + COMMENTS)
# e.g. LD1 {v16.16b, v17.16b, v18.16b}, [x5], 48
INSTR_REGLIST_INDIV_MEMOP_IMM = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                           MEMOP + COMMA + IMM + COMMENTS)
# e.g. VST1.32 {d24-d25}, [r11]{!}
INSTR_REGLIST_CONSEC_MEMOP = re.compile(INSTR + REGLIST_CONSEC + COMMA +
                                        MEMOP_MAYBE_WB + COMMENTS)
# e.g. VLD1.32 {d0[]}, [r3]!
INSTR_REGLIST_REPLICATE_MEMOP = re.compile(INSTR + REGLIST_INDIV_REPLICATE +
                                           COMMA + MEMOP + r'(!)?' + COMMENTS)
# e.g. VST1.32 {d16[0]}, [r11]{!}
INSTR_REGLIST_INDEX_MEMOP = re.compile(INSTR + REGLIST_INDEX + COMMA +
                                       MEMOP_MAYBE_WB + COMMENTS)
# e.g. VMRS APSR_nzcv, FPSCR
INSTR_REG_FPSCR = re.compile(INSTR + f'({APSR}|{REG_NO_GROUP})' + COMMA +
                             FPSCR + COMMENTS)

# e.g. PRFM PLDL1KEEP, [x5]
INSTR_PLD_MEMOP = re.compile(INSTR + r'(PLDL1KEEP)' + COMMA + MEMOP + COMMENTS)
# e.g. PRFM PLDL1KEEP, [x5, 64]
INSTR_PLD_MEMOP_OFFSET = re.compile(INSTR + r'(PLDL1KEEP)' + COMMA + MEMOP_OFFSET + COMMENTS)

# Condition code operand, e.g. LO.
COND = r'([A-Z]+)'
# e.g. CSEL x9, x3, x9, LO
INSTR_REG_REG_REG_COND_RE = re.compile(INSTR + REG + COMMA + REG + COMMA + REG + COMMA + COND + COMMENTS)
135
136
def remove_brackets(s):
  """Returns `s` with every '[' and ']' character dropped."""
  return ''.join(ch for ch in s if ch not in '[]')
139
140
def fix_replicate_instruction(s):
  """Inserts an 'r' before the first '_<digits>' suffix of an instruction name.

  e.g. vld1_32 -> vld1r_32. Names without such a suffix are returned unchanged.
  """
  # `count` is passed by keyword: passing it positionally to re.sub is
  # deprecated.
  return re.sub(r'_(\d+)', r'r_\1', s, count=1)
143
144
def fix_instr_name(s):
  """Turns an assembly mnemonic into the name used for the emitted call.

  The name is lower-cased, its first two '.' become '_', and the first
  occurrence of 'and' gets a trailing underscore.
  """
  lowered = s.lower()
  underscored = lowered.replace('.', '_', 2)
  return underscored.replace('and', 'and_', 1)
147
148
def fix_comments(s):
  """Replaces the first '#' in `s` with '//'; `s` is returned as-is if none."""
  before, hash_mark, after = s.partition('#')
  if not hash_mark:
    return s
  return f'{before}//{after}'
151
152
def maybe_wb(wb):
  """Returns '++' when a writeback marker was matched (truthy), else ''."""
  if wb:
    return '++'
  return ''
155
156
def fix_fn_name(name):
  """Derives the generator function name from a microkernel function name.

  A leading 'xnn_' is stripped, every 'minmax_' (activation type) is removed,
  and the result is prefixed with 'xnn_generate_'.
  """
  prefix = 'xnn_'
  stem = name[len(prefix):] if name.startswith(prefix) else name
  # Removing 'minmax_' is a no-op when the name has no activation in it.
  stem = stem.replace('minmax_', '')
  return 'xnn_generate_' + stem
164
165
def remove_prfm_from_fn_name(name):
  """Returns `name` with every 'prfm_' removed; `name` must contain '_prfm_'."""
  assert '_prfm_' in name
  return ''.join(name.split('prfm_'))
169
170
def fix_regs(regs):
  """Rewrites vector registers that carry a datatype into method calls.

  e.g. v2.4s -> v2.v4s(), v2.s -> v2.s(). Text without a '<reg>.<type>' part
  is left untouched.
  """
  typed_reg = re.compile(r'(\w+\.)(\d+)?(\w+)')

  def as_method_call(match):
    base, lanes, kind = match.groups()
    # The lane count is optional; when present it is emitted with a 'v' prefix.
    lane_part = f'v{lanes}' if lanes else ''
    return f'{base}{lane_part}{kind}()'

  return typed_reg.sub(as_method_call, regs)
180
181
# Regexes (applied with re.fullmatch) for lines before the function body that
# are dropped entirely, e.g. a bare assembler directive such as `.text`.
IGNORE_LINES = [r'\s*\.\w+']

# Target architectures; main() picks one based on the input file name.
AARCH32 = 'aarch32'
AARCH64 = 'aarch64'
# Microkernel kinds; main() picks IGEMM when the file name contains 'igemm'.
GEMM = 'GEMM'
IGEMM = 'IGEMM'
188
def main(input_file):
  """Converts one assembly microkernel file to JIT C++ and prints it to stdout.

  The architecture (aarch32/aarch64), the kernel type (GEMM/IGEMM) and the
  minmax/prfm variants are all inferred from substrings of `input_file`.
  Lines before BEGIN_FUNCTION become the prologue (copyright, includes, class
  declaration); lines after it are matched against the instruction regexes and
  translated one by one until END_FUNCTION.

  Args:
    input_file: Path of the assembly (.S) file to convert.

  Exits with status 1 (after printing an error to stderr) when the architecture
  cannot be determined from the file name or when a line inside the function
  body is not recognized.
  """
  arch = None
  kernel_type = GEMM
  minmax = False
  prfm = False
  ctype = 'float'

  if 'aarch32' in input_file:
    arch = AARCH32
  elif 'aarch64' in input_file:
    arch = AARCH64
  else:
    # stdout carries the generated C++, so errors must go to stderr.
    print('ERROR: unknown architecture', file=sys.stderr)
    sys.exit(1)

  if 'igemm' in input_file:
    kernel_type = IGEMM
  if 'minmax' in input_file:
    minmax = True
  if 'prfm' in input_file:
    prfm = True

  # Whether we are in the microkernel function.
  in_function = False
  # Instructions that make up the microkernel.
  instructions = []
  # Lines of code or comments before the actual function body.
  prologue = []
  # All labels need to be declared first, collect them and output them after
  # function signature.
  labels = []
  # Name of the microkernel function.
  fn_name = ''
  sc = ';'
  # Whether we are in the auto-generated comment.
  in_autogen = False

  with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
      line = line.rstrip()

      # Handle all lines before the microkernel instructions begin.
      if not in_function:
        if 'Auto-generated file' in line:
          in_autogen = True
          continue
        elif 'BEGIN_FUNCTION' in line:
          in_function = True
          fn_name = line.split()[1]
          # NOTE(review): the fixed 20-character slice presumably strips a
          # known leading directory from the path -- confirm against callers.
          prologue.append(f'// Converted from: {input_file[20:]}')
          params = 'float min, float max' if minmax else 'void* params'
          prefetch = 'bool prefetch, ' if prfm else ''
          if kernel_type == GEMM:
            prologue.append(f'void Generator::generate({prefetch}size_t max_mr, size_t nc_mod_nr, size_t kc, {params}) {{')
          else:
            prologue.append(f'void Generator::generate({prefetch}size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, {params}) {{')
          continue
        elif 'Copyright ' in line:
          in_autogen = False
          # replace year
          prologue.append(
              re.sub(r'\d{4}', str(datetime.date.today().year), line,
                     count=1).rstrip())
          continue
        elif '#include <xnnpack/assembly.h>' in line:
          prologue.append('#include <cassert>')
          prologue.append('#include <cstddef>')
          prologue.append('#include <limits>')
          prologue.append('')
          prologue.append(f'#include <xnnpack/{arch}-assembler.h>')
          prologue.append('#include <xnnpack/allocator.h>')
          if kernel_type == GEMM:
            prologue.append('#include <xnnpack/gemm.h>')
          else:
            prologue.append('#include <xnnpack/igemm.h>')
          prologue.append('')
          prologue.append('namespace xnnpack {')
          prologue.append(f'namespace {arch} {{')
          prologue.append('namespace {')
          prologue.append('class Generator : public Assembler {')
          prologue.append('  using Assembler::Assembler;')
          prologue.append(' public:')
          params = 'float min, float max' if minmax else 'void* params'
          prefetch = 'bool prefetch, ' if prfm else ''
          if kernel_type == GEMM:
            prologue.append(f'  void generate({prefetch}size_t max_mr, size_t nc_mod_nr, size_t kc, {params});')
          else:
            prologue.append(f'  void generate({prefetch}size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, {params});')
          prologue.append('};')
          continue
        elif any(re.fullmatch(p, line) for p in IGNORE_LINES):
          continue
        elif in_autogen:
          continue
        else:
          prologue.append(fix_comments(line.rstrip()))
          continue
      # end if not in_function

      # We are now in the microkernel function body.
      # Don't keep the ifdefs.
      m = re.fullmatch(IFDEF_RE, line)
      if m:
        continue
      # But keep other comments.
      m = re.fullmatch(COMMENT_RE, line)
      if m:
        instructions.append(m[1])
        continue

      m = re.fullmatch(LABEL, line)
      if m:
        labels.append(m[1])
        instructions.append(f'bind(l{m[1]}){sc}')
        continue
      m = re.fullmatch(INSTR_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}(){sc} {m[2]}')
        continue
      m = re.fullmatch(INSTR_OP_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({m[2]}){sc} {m[3]}')
        continue
      m = re.fullmatch(INSTR_REGLIST_CONSEC_MEMOP_REG, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({{{m[2]}-{m[3]}}}, mem[{m[4]}], {m[5]}){sc} {m[6]}'
        )
        continue
      m = re.fullmatch(INSTR_REGLIST_INDIV_MEMOP_REG, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({{{fix_regs(m[2])}}}, mem[{m[3]}], {m[4]}){sc} {m[5]}')
        continue
      m = re.fullmatch(INSTR_REGLIST_CONSEC_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({{{m[2]}-{m[3]}}}){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_REGLIST_LIST_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({{{m[2]}}}){sc} {m[3]}')
        continue
      m = re.fullmatch(INSTR_MEMOP_OFFSET_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}(mem[{m[2]}, {m[3]}]){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_REG_MEMOP_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({m[2]}, mem[{m[3]}]){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_REG_MEMOP_IMM_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({m[2]}, mem[{m[3]}], {m[4]}){sc} {m[5]}')
        continue
      m = re.fullmatch(INSTR_REG_MEMOP_OFFSET_RE, line)
      if m:
        if m[5]: # wb
          instructions.append(
              f'{fix_instr_name(m[1])}({m[2]}, mem[{m[3]}, {m[4]}]++){sc} {m[6]}')
        else: # no wb
          instructions.append(
              f'{fix_instr_name(m[1])}({m[2]}, mem[{m[3]}, {m[4]}]){sc} {m[6]}')
        continue
      m = re.fullmatch(INSTR_REG_REG_MEMOP_RE, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, mem[{m[4]}]){sc} {m[5]}')
        continue
      m = re.fullmatch(INSTR_REG_REG_MEMOP_OFFSET_RE, line)
      if m:
        if m[6]: # wb
          instructions.append(
              f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, mem[{m[4]}, {m[5]}]++){sc} {m[7]}')
        else: # no wb
          instructions.append(
              f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, mem[{m[4]}, {m[5]}]){sc} {m[7]}')
        continue
      m = re.fullmatch(INSTR_REG_REG_MEMOP_IMM_RE, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, mem[{m[4]}], {m[5]}){sc} {m[6]}')
        continue
      m = re.fullmatch(INSTR_REG_IMM_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({fix_regs(m[2])}, {m[3]}){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_REG_REG_REG_RE, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({fix_regs(m[2])}, {fix_regs(m[3])}, {fix_regs(m[4])}){sc} {m[5]}')
        continue
      m = re.fullmatch(INSTR_REG_REG_REG_IMM_RE, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, {m[4]}, {m[5]}){sc} {m[6]}')
        continue
      m = re.fullmatch(INSTR_REG_REG_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({fix_regs(m[2])}, {fix_regs(m[3])}){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_REG_REGLIST_CONSECT, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}(mem[{m[2]}], {{{m[3]}-{m[4]}}}){sc} {m[5]}')
        continue
      m = re.fullmatch(INSTR_REG_REGLIST_CONSECT_WB, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}(mem[{m[2]}]++, {{{m[3]}-{m[4]}}}){sc} {m[5]}')
        continue
      m = re.fullmatch(INSTR_REG_REGLIST_INDIV_WB, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}(mem[{m[2]}]++, {{{m[3]}}}){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_B_IMM, line)
      if m:
        # The f/b direction suffix (m[3]) is not part of the emitted label.
        instructions.append(f'{fix_instr_name(m[1])}(l{m[2]}){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_B_REG_IMM_IMM, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, l{m[4]}){sc} {m[6]}')
        continue
      m = re.fullmatch(INSTR_REGLIST_INDIV_MEMOP, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({{{fix_regs(m[2])}}}, mem[{m[3]}]{maybe_wb(m[4])}){sc} {m[5]}'
        )
        continue
      m = re.fullmatch(INSTR_REGLIST_INDIV_MEMOP_IMM, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({{{fix_regs(m[2])}}}, mem[{m[3]}], {m[4]}){sc} {m[5]}'
        )
        continue
      m = re.fullmatch(INSTR_REGLIST_CONSEC_MEMOP, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({{{m[2]}-{m[3]}}}, mem[{m[4]}]{maybe_wb(m[5])}){sc} {m[6]}'
        )
        continue
      m = re.fullmatch(INSTR_REGLIST_REPLICATE_MEMOP, line)
      if m:
        if m[5]:
          instructions.append(
              f'{fix_replicate_instruction(fix_instr_name(m[1]))}({{{remove_brackets(m[2])}}}, mem[{m[4]}]++){sc} {m[6]}'
          )
        else:
          instructions.append(
              f'{fix_replicate_instruction(fix_instr_name(m[1]))}({{{remove_brackets(m[2])}}}, mem[{m[4]}]){sc} {m[6]}'
          )
        continue
      m = re.fullmatch(INSTR_REGLIST_INDEX_MEMOP, line)
      if m:
        instructions.append(
            f'{fix_instr_name(m[1])}({{{m[2]}}}, mem[{m[3]}]{maybe_wb(m[4])}){sc} {m[5]}'
        )
        continue
      m = re.fullmatch(P2ALIGN_RE, line)
      if m:
        # .p2align takes a power of two; align() takes the byte count.
        instructions.append(f'align({1 << int(m[1])}){sc}')
        continue
      m = re.fullmatch(INSTR_REG_FPSCR, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({m[2]}, {m[3]}){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_PLD_MEMOP, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}(k{m[2]}, mem[{m[3]}]){sc} {m[4]}')
        continue
      m = re.fullmatch(INSTR_PLD_MEMOP_OFFSET, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}(k{m[2]}, mem[{m[3]}, {m[4]}]){sc} {m[5]}')
        continue
      m = re.fullmatch(INSTR_REG_REG_REG_COND_RE, line)
      if m:
        instructions.append(f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, {m[4]}, k{m[5]}){sc} {m[6]}')
        continue

      # Keep empty lines for formatting
      if line.strip() == '':
        instructions.append('')
        continue

      # Assembly directives that we don't care about.
      if line.strip().startswith('.'):
        continue

      if line.startswith('END_FUNCTION'):
        break

      # All other lines are error.
      print(f'ERROR: {line}', file=sys.stderr)
      sys.exit(1)

  # Actually emit the JIT codegen (to stdout).
  for p in prologue:
    print(p)


  # The tile size (MRxNR) is taken from the first "<digits>x<digits>" in the
  # file name; both stay 0 when the name has no such part.
  m = re.search(r'(\d+)x(\d+)', input_file)
  mr = 0
  nr = 0
  if m:
    mr = m[1]
    nr = m[2]
  labels_str = ', '.join(f'l{l}' for l in labels)
  print(f'  assert(max_mr <= {mr});')
  print(f'  assert(nc_mod_nr < {nr});')
  print('  assert(kc != 0);')
  print(f'  assert(kc % sizeof({ctype}) == 0);')
  print()
  # A declaration without any names ("Label ;") is not valid C++, so it is
  # only emitted when the kernel actually defines labels.
  if labels_str:
    print(f'  Label {labels_str};')
    print()
  if minmax:
    print('  // const bool clamp_min = min != -std::numeric_limits<float>::infinity();')
    print('  // const bool clamp_max = max != +std::numeric_limits<float>::infinity();')

  indent = '  '
  for i in instructions:
    if i.strip().startswith('#'):
      print(indent + fix_comments(i))
    elif i.strip().startswith('//'):
      print(indent + i)
    elif i.strip() == '':
      print()
    else:
      print(indent + (i).rstrip())
  print(indent + 'align(16, AlignInstruction::kHlt);')

  print('}')
  print('}  // namespace')
  print(f'}}  // {arch}')
  print('}  // xnnpack')
  print('')
  if prfm:
    # Prefetching kernels get two entry points: one with prefetch disabled
    # (named without "prfm_") and one with it enabled.
    print_generator_definition(kernel_type, remove_prfm_from_fn_name(fn_name), arch, minmax, prefetch='false, ')
    print()
    print_generator_definition(kernel_type, fn_name, arch, minmax, prefetch='true, ')
  else:
    print_generator_definition(kernel_type, fn_name, arch, minmax)
533
534
def print_generator_definition(kernel_type, fn_name, arch, minmax, prefetch=''):
  """Prints the C++ `xnn_generate_*` entry point that drives the Generator.

  Args:
    kernel_type: GEMM or IGEMM; IGEMM adds the `ks` parameter/argument.
    fn_name: Microkernel function name; converted with fix_fn_name().
    arch: Architecture namespace name used in the `using namespace` line.
    minmax: Whether min/max are read from the params and passed to generate().
    prefetch: Text inserted verbatim as the leading generate() argument(s).
  """
  is_gemm = kernel_type == GEMM
  ks_decl = '' if is_gemm else 'size_t ks, '
  ks_arg = '' if is_gemm else 'ks, '
  if minmax:
    tail_args = 'gemm_params->f32_minmax.min, gemm_params->f32_minmax.max'
  else:
    tail_args = 'nullptr'

  out = [
      f'xnn_status {fix_fn_name(fn_name)}(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, {ks_decl}const void* params) {{',
      f'  using namespace xnnpack::{arch};',
      '  Generator g(code);',
  ]
  if minmax:
    out.append('  assert(params != nullptr);')
    out.append('  const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);')
  out.append(f'  g.generate({prefetch}max_mr, nc_mod_nr, kc, {ks_arg}{tail_args});')
  out.extend([
      '  g.finalize();',
      '  if (g.error() != xnnpack::Error::kNoError) {',
      '    return xnn_status_invalid_state;',
      '  }',
      '  return xnn_status_success;',
      '}',
  ])
  for text in out:
    print(text)
561
562
# Script entry point: parse the single positional argument and convert it.
if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Convert assembly to JIT C++, writes to stdout.')
  parser.add_argument('input_file', help='Input assembly filename')
  args = parser.parse_args()
  main(args.input_file)
568