xref: /aosp_15_r20/external/gemmlowp/standalone/encode.py (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1# Copyright 2018 The gemmlowp Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Encodes ARM asm code for certain instructions into the corresponding machine code encoding, as a .word directive in the asm code, preserving the original code in a comment.
16
17Reads from stdin, writes to stdout.
18
19Example diff:
20-        "udot v16.4s, v4.16b, v0.16b\n"
21+        ".word 0x6e809490  // udot v16.4s, v4.16b, v0.16b\n"
22
23The intended use case is to make asm code easier to compile on toolchains that
24do not support certain new instructions.
25"""
26
27import sys
28import re
29import argparse
30
31
def encode_udot_sdot_vector(line):
  """Encodes the first vector-form UDOT/SDOT instruction found in `line`.

  Recognizes `udot vD.4s, vN.16b, vM.16b` (and the `sdot` equivalent) and
  computes its 32-bit machine code.

  Args:
    line: a line of asm source text.

  Returns:
    A tuple (mcode, match). If an instruction was found, mcode is its nonzero
    encoding and match is the exact matched instruction text. Otherwise
    mcode is 0 and match is the unmodified `line`.

  Raises:
    ValueError: if a register number is outside v0..v31.
  """
  m = re.search(
      r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b',
      line)
  if not m:
    return 0, line

  match = m.group(0)
  unsigned = 1 if m.group(1) == 'u' else 0
  accum = int(m.group(2))
  lhs = int(m.group(3))
  rhs = int(m.group(4))
  # Explicit checks rather than `assert`: asserts are stripped under
  # `python -O`, and an out-of-range register would then spill into the
  # neighbouring bit field and silently yield a wrong encoding.
  for name, reg in (('accumulator', accum), ('lhs', lhs), ('rhs', rhs)):
    if not 0 <= reg <= 31:
      raise ValueError(
          '%s register v%d out of range in %r' % (name, reg, match))
  # Field layout: accumulator in bits [4:0], lhs in [9:5], rhs in [20:16],
  # and bit 29 selects the unsigned (udot) variant.
  mcode = 0x4e809400 | (accum << 0) | (lhs << 5) | (rhs << 16) | (
      unsigned << 29)
  return mcode, match
50
51
def encode_udot_sdot_element(line):
  """Encodes the first by-element UDOT/SDOT instruction found in `line`.

  Recognizes `udot vD.4s, vN.16b, vM.4b[i]` (and the `sdot` equivalent) and
  computes its 32-bit machine code.

  Args:
    line: a line of asm source text.

  Returns:
    A tuple (mcode, match). If an instruction was found, mcode is its nonzero
    encoding and match is the exact matched instruction text. Otherwise
    mcode is 0 and match is the unmodified `line`.

  Raises:
    ValueError: if a register number is outside v0..v31 or the lane index
      is outside 0..3.
  """
  m = re.search(
      r'\b([us])dot[ ]+v([0-9]+)[ ]*\.[ ]*4s[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*16b[ ]*\,[ ]*v([0-9]+)[ ]*\.[ ]*4b[ ]*\[([0-9])\]',
      line)
  if not m:
    return 0, line

  match = m.group(0)
  unsigned = 1 if m.group(1) == 'u' else 0
  accum = int(m.group(2))
  lhs = int(m.group(3))
  rhs = int(m.group(4))
  lanegroup = int(m.group(5))
  # Explicit checks rather than `assert`: asserts are stripped under
  # `python -O`, and an out-of-range value would then spill into
  # neighbouring bit fields and silently yield a wrong encoding.
  for name, reg in (('accumulator', accum), ('lhs', lhs), ('rhs', rhs)):
    if not 0 <= reg <= 31:
      raise ValueError(
          '%s register v%d out of range in %r' % (name, reg, match))
  if not 0 <= lanegroup <= 3:
    raise ValueError('lane index %d out of range in %r' % (lanegroup, match))
  # The 2-bit lane index is split across the word: its low bit goes to
  # bit 21 and its high bit to bit 11. rhs occupies bits [20:16], so all of
  # v0..v31 are encodable; bit 29 selects the unsigned (udot) variant.
  lane_lo = lanegroup & 1
  lane_hi = (lanegroup >> 1) & 1
  mcode = (0x4f80e000 | (accum << 0) | (lhs << 5) | (rhs << 16) |
           (lane_lo << 21) | (lane_hi << 11) | (unsigned << 29))
  return mcode, match
75
76
def encode(line):
  """Runs the known instruction encoders over `line`, in order.

  Returns:
    The (mcode, match) pair produced by the first encoder reporting a
    nonzero encoding, or (0, line) when no encoder recognizes anything.
  """
  # Generator keeps evaluation lazy: later encoders are not called once one
  # has succeeded.
  attempts = (
      try_encode(line)
      for try_encode in (encode_udot_sdot_vector, encode_udot_sdot_element))
  return next((result for result in attempts if result[0]), (0, line))
83
84
def read_existing_encoding(line):
  """Returns the value of the first `.word 0x...` directive in `line`.

  Hex digits are accepted in either case, and any run of whitespace may
  separate `.word` from the literal. This means hand-written encodings such
  as `.word 0x6E809490` or `.word  0x6e809490` are read in full instead of
  being truncated at the first uppercase digit or missed entirely.

  Args:
    line: a line of asm source text.

  Returns:
    The encoded word as an int, or 0 if the line has no `.word 0x...`
    directive.
  """
  m = re.search(r'\.word\s+(0x[0-9a-fA-F]+)', line)
  if m:
    return int(m.group(1), 16)
  return 0
90
91
parser = argparse.ArgumentParser(description='Encode some A64 instructions.')
parser.add_argument(
    '-f',
    '--fix',
    help='fix existing wrong encodings in-place and continue',
    action='store_true')
args = parser.parse_args()

# Locates the hex literal of an existing `.word` directive. Group 1 keeps the
# directive and its original whitespace so a fix rewrites only the number,
# whatever the surrounding formatting (hex case, spacing, comment style).
_EXISTING_WORD_RE = re.compile(r'(\.word\s+)0x[0-9a-fA-F]+')

found_existing_encodings = False
found_error = False
found_fixes = False
for lineno, line in enumerate(sys.stdin, start=1):
  mcode, match = encode(line)
  if mcode:
    existing_encoding = read_existing_encoding(line)
    if existing_encoding:
      found_existing_encodings = True
      if mcode != existing_encoding:
        if args.fix:
          fixed_line = _EXISTING_WORD_RE.sub(
              r'\g<1>0x%x' % mcode, line, count=1)
          if fixed_line == line:
            # Never claim a fix that did not happen.
            sys.stderr.write(
                "Error at line %d: could not rewrite existing encoding 0x%x to 0x%x for instruction '%s':\n\n%s\n\n"
                % (lineno, existing_encoding, mcode, match, line))
            found_error = True
          else:
            line = fixed_line
            found_fixes = True
        else:
          sys.stderr.write(
              "Error at line %d: existing encoding 0x%x differs from encoding 0x%x for instruction '%s':\n\n%s\n\n"
              % (lineno, existing_encoding, mcode, match, line))
          found_error = True
    else:
      # Not yet encoded: replace the asm text with a .word directive and keep
      # the original instruction as a trailing comment.
      line = line.replace(match, '.word 0x%x  // %s' % (mcode, match))
  sys.stdout.write(line)
if found_error:
  sys.exit(1)
if found_existing_encodings:
  if found_fixes:
    sys.stderr.write(
        'Note: some instructions that this program is able to encode, were already encoded and their existing encodings didn\'t match the specified asm instructions. Since --fix was passed, these were fixed in-place.\n'
    )
  else:
    sys.stderr.write(
        'Note: some instructions that this program is able to encode, were already encoded. These encodings have been checked.\n'
    )
135