xref: /aosp_15_r20/external/mesa3d/src/panfrost/compiler/valhall/asm.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1#encoding=utf-8
2
3# Copyright (C) 2021 Collabora, Ltd.
4#
5# Permission is hereby granted, free of charge, to any person obtaining a
6# copy of this software and associated documentation files (the "Software"),
7# to deal in the Software without restriction, including without limitation
8# the rights to use, copy, modify, merge, publish, distribute, sublicense,
9# and/or sell copies of the Software, and to permit persons to whom the
10# Software is furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice (including the next
13# paragraph) shall be included in all copies or substantial portions of the
14# Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22# IN THE SOFTWARE.
23
24import argparse
25import sys
26import struct
27from valhall import valhall_parse_isa
28
# Parse the ISA description once at import time; parse_asm() and the
# operand encoders look instructions, immediates and enums up in these.
instructions, immediates, enums, typesize, safe_name = valhall_parse_isa()

# Text of the line currently being assembled, printed by die() for context.
LINE = ''
32
class ParseError(Exception):
    """Raised by die() when an assembly line cannot be encoded.

    The human-readable reason is kept in ``error``. It is also passed to the
    base class so ``args`` and ``str()`` carry it.
    """

    def __init__(self, error):
        super().__init__(error)
        self.error = error
36
class FAUState:
    """Track the FAU operands used by a single instruction.

    push() enforces the per-instruction limits:
      * at most two distinct FAU sources,
      * at most one FAU slot and two words (two slots / three words when
        ``message`` is set).

    set_page() checks that all FAU operands agree on one page.
    """

    def __init__(self, message = False):
        # Message instructions may use more FAU slots and words (see push()).
        self.message = message
        # Page shared by every FAU operand, or None until one sets it.
        self.page = None
        # Encoded FAU sources that occupy an FAU word (small constants excluded).
        self.words = set()
        # Every distinct encoded FAU source, including small constants.
        self.buffer = set()

    def set_page(self, page):
        """Record the FAU page used by an operand.

        Dies if a different page was already recorded for this instruction.
        """
        assert page <= 3
        die_if(self.page is not None and self.page != page, 'Mismatched pages')
        self.page = page

    def push(self, source):
        """Account for one encoded 8-bit source operand.

        Register operands (bit 7 clear) are ignored. Dies when the
        instruction exceeds its FAU buffer, slot or word limits.
        """
        if not (source & (1 << 7)):
            # Registers do not consume FAU resources.
            return

        self.buffer.add(source)
        die_if(len(self.buffer) > 2, "Overflowed FAU buffer")

        if (source >> 5) == 0b110:
            # Small constants count towards the buffer but occupy no FAU
            # slot or word.
            return

        self.words.add(source)

        # Bit 0 selects the word within a slot, so a slot is source >> 1.
        slots = {word >> 1 for word in self.words}
        die_if(len(slots) > (2 if self.message else 1), 'Too many FAU slots')
        die_if(len(self.words) > (3 if self.message else 2), 'Too many FAU words')
69
def die(s):
    """Report a fatal assembly error.

    When imported (for example by a test harness), raise ParseError so the
    caller can handle it. When run standalone we are talking to a human:
    print the offending line and the message, then exit with status 1.
    """
    if __name__ != "__main__":
        raise ParseError(s)

    print(LINE)
    print(s)
    sys.exit(1)
79
def die_if(cond, s):
    """Call die(s) when *cond* is truthy; otherwise do nothing."""
    if not cond:
        return
    die(s)
83
def parse_int(s, minimum, maximum):
    """Parse *s* as a Python integer literal and check it is in range.

    The prefix is auto-detected (``int(s, base=0)``), so ``0x``/``0o``/``0b``
    literals are accepted. Dies if *s* is not a number or lies outside
    [minimum, maximum]; otherwise returns the parsed value.
    """
    try:
        value = int(s, base = 0)
    except ValueError:
        die(f"Expected number {s}")

    if not (minimum <= value <= maximum):
        die(f"Range error on {s}")

    return value
94
def encode_source(op, fau):
    """Encode a source operand (without modifiers) into its 8-bit field.

    Accepted forms:
      * ``rN``    register N (0-63)
      * ``^rN``   register N with the discard bit (0x40) set
      * ``uN``    FAU uniform N (0-127); N >> 6 selects the FAU page
      * ``i..N``  a raw 0xC0-based index; N is read from op[3:]
                  (presumably written ``immN`` -- confirm against users)
      * ``0x...`` a value from the immediates table, encoded by its index
      * a special FAU value name from pages 0, 1 or 3

    ``fau`` receives the page of uniform and special operands.
    Dies on malformed or out-of-range operands.
    """
    die_if(len(op) == 0, 'Invalid operand')

    if op[0] == '^':
        # Slice rather than index so a bare '^' dies instead of raising.
        die_if(op[1:2] != 'r', f"Expected register after discard {op}")
        return parse_int(op[2:], 0, 63) | 0x40
    elif op[0] == 'r':
        return parse_int(op[1:], 0, 63)
    elif op[0] == 'u':
        val = parse_int(op[1:], 0, 127)
        fau.set_page(val >> 6)
        return (val & 0x3F) | 0x80
    elif op[0] == 'i':
        # Range-check so the value cannot spill outside the 8-bit source
        # field, and report malformed numbers via die() instead of a
        # ValueError traceback.
        return parse_int(op[3:], 0, 63) | 0xC0
    elif op.startswith('0x'):
        try:
            val = int(op, base=0)
        except ValueError:
            die('Expected value')

        die_if(val not in immediates, 'Unexpected immediate value')
        return immediates.index(val) | 0xC0

    for page in (0, 1, 3):
        names = enums[f'fau_special_page_{page}'].bare_values
        if op in names:
            fau.set_page(page)
            # Special values start at index 32, two words apart; the .w1
            # word select later sets bit 0.
            return (32 + (names.index(op) << 1)) | 0xC0

    die('Invalid operand')
123
124
def encode_dest(op):
    """Encode a destination register operand.

    ``op`` is ``rN`` (N in 0-63) optionally followed by exactly one
    ``.h0`` or ``.h1`` write mask. ``h0`` sets mask bit 0 and ``h1`` sets
    mask bit 1. Without a mask both bits are set (write in full). Returns
    the register number with the 2-bit mask placed at bit 6.
    """
    die_if(op[0] != 'r', f"Expected register destination {op}")

    reg, *modifiers = op.split(".")
    write_masks = ["h0", "h1"]

    if not modifiers:
        # Default to writing in full
        wrmask = 0b11
    else:
        die_if(len(modifiers) > 1, "Too many modifiers")
        die_if(modifiers[0] not in write_masks, "Expected a write mask")
        wrmask = 1 << write_masks.index(modifiers[0])

    return parse_int(reg[1:], 0, 63) | (wrmask << 6)
142
def parse_asm(line):
    """Assemble one line of Valhall assembly into a 64-bit instruction word.

    Syntax: ``MNEMONIC[.mod]... op, op, ...``. Operands appear in this order:
      * staging registers (``@`` or ``@rA:rB:...``)
      * destinations (see encode_dest)
      * sources (see encode_source, each optionally followed by
        ``.modifier`` suffixes)
      * immediates (``name:value``, or ``#value`` for the immediate named
        ``constant``)

    Modifiers after the mnemonic are either flow-control values or values of
    one of the instruction's modifier enums. Enums that are not specified
    fall back to their default; an enum without a default is required.

    Returns the encoding as a Python int. Errors are reported through die(),
    which raises ParseError when this module is imported.
    """
    global LINE
    LINE = line  # For better errors
    encoded = 0

    # Figure out mnemonic: the longest instruction name that prefixes the
    # first space-separated word wins.
    head = line.split(" ")[0]
    opts = [ins for ins in instructions if head.startswith(ins.name)]
    opts = sorted(opts, key=lambda x: len(x.name), reverse=True)

    if len(opts) == 0:
        die(f"No known mnemonic for {head}")

    if len(opts) > 1 and len(opts[0].name) == len(opts[1].name):
        # Report through die() like every other error, so an importing test
        # harness gets a ParseError rather than a process exit.
        listing = "\n".join(f"  {ins}" for ins in opts)
        die(f"Ambiguous mnemonic for {head}\nOptions:\n{listing}")

    ins = opts[0]

    # Split off modifiers
    if len(head) > len(ins.name) and head[len(ins.name)] != '.':
        die(f"Expected . after instruction in {head}")

    mods = head[len(ins.name) + 1:].split(".")
    modifier_map = {}
    modifier_names = [x.name for x in ins.modifiers]

    tail = line[(len(head) + 1):]
    operands = [x.strip() for x in tail.split(",") if len(x.strip()) > 0]
    expected_op_count = len(ins.srcs) + len(ins.dests) + len(ins.immediates) + len(ins.staging)
    if len(operands) != expected_op_count:
        die(f"Wrong number of operands in {line}, expected {expected_op_count}, got {len(operands)} {operands}")

    # Staging registers: "@" for none, otherwise "@rA:rB:..." naming
    # consecutive registers that start on an even register when more than one.
    for op, sr in zip(operands, ins.staging):
        die_if(op[0] != '@', f'Expected staging register, got {op}')
        parts = op[1:].split(':')

        if op == '@':
            parts = []

        # startswith() also rejects empty parts such as the one in "@r0:".
        die_if(any(not x.startswith('r') for x in parts), f'Expected registers, got {op}')
        regs = [parse_int(x[1:], 0, 63) for x in parts]

        # Written staging operands of instructions with a write-count
        # modifier may use one extra register.
        extended_write = "staging_register_write_count" in modifier_names and sr.write
        max_sr_count = 8 if extended_write else 7

        sr_count = len(regs)
        die_if(sr_count > max_sr_count, f'Too many staging registers {sr_count}')

        base = regs[0] if len(regs) > 0 else 0
        die_if(any(reg != (base + i) for i, reg in enumerate(regs)),
                f'Expected consecutive staging registers, got {op}')
        die_if(sr_count > 1 and (base % 2) != 0,
                'Consecutive staging registers must be aligned to a register pair')

        if sr.count == 0:
            # Variable-length staging: the count goes into a modifier field.
            if extended_write:
                # The write count is stored minus one, so zero registers
                # would become -1 and corrupt the whole encoding.
                die_if(sr_count == 0, f'Expected at least one staging register, got {op}')
                modifier_map["staging_register_write_count"] = sr_count - 1
            else:
                assert "staging_register_count" in modifier_names
                modifier_map["staging_register_count"] = sr_count
        else:
            die_if(sr_count != sr.count, f"Expected {sr.count} staging registers, got {sr_count}")

        encoded |= ((sr.encoded_flags | base) << sr.start)
    operands = operands[len(ins.staging):]

    # Destinations go in the 8-bit field at bit 40.
    for op, _dest in zip(operands, ins.dests):
        encoded |= encode_dest(op) << 40
    operands = operands[len(ins.dests):]

    if len(ins.dests) == 0 and len(ins.staging) == 0:
        # Set a placeholder writemask to prevent encoding faults
        encoded |= (0xC0 << 40)

    fau = FAUState(message = ins.message)

    for op, src in zip(operands, ins.srcs):
        parts = op.split('.')
        encoded_src = encode_source(parts[0], fau)

        # Special FAU values (0b111xxxxx) take a .w0/.w1 word select; only
        # w1 changes the encoding (bit 0). NOTE: omitting the select is not
        # currently rejected.
        needs_word_select = ((encoded_src >> 5) == 0b111)

        # Has a swizzle been applied yet?
        swizzled = False

        for mod in parts[1:]:
            # Encode the modifier
            if mod in src.offset and src.bits[mod] == 1:
                encoded |= (1 << src.offset[mod])
            elif src.halfswizzle and mod in enums[f'half_swizzles_{src.size}_bit'].bare_values:
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums[f'half_swizzles_{src.size}_bit'].bare_values.index(mod)
                encoded |= (val << src.offset['widen'])
            elif mod in enums[f'swizzles_{src.size}_bit'].bare_values and (src.widen or src.lanes):
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums[f'swizzles_{src.size}_bit'].bare_values.index(mod)
                encoded |= (val << src.offset['widen'])
            elif src.lane and mod in enums[f'lane_{src.size}_bit'].bare_values:
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums[f'lane_{src.size}_bit'].bare_values.index(mod)
                encoded |= (val << src.offset['lane'])
            elif src.combine and mod in enums['combine'].bare_values:
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums['combine'].bare_values.index(mod)
                encoded |= (val << src.offset['combine'])
            elif src.size == 32 and mod in enums['widen'].bare_values:
                die_if(not src.swizzle, "Instruction doesn't take widens")
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums['widen'].bare_values.index(mod)
                encoded |= (val << src.offset['swizzle'])
            elif src.size == 16 and mod in enums['swizzles_16_bit'].bare_values:
                die_if(not src.swizzle, "Instruction doesn't take swizzles")
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums['swizzles_16_bit'].bare_values.index(mod)
                encoded |= (val << src.offset['swizzle'])
            elif mod in enums['lane_8_bit'].bare_values:
                die_if(not src.lane, "Instruction doesn't take a lane")
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums['lane_8_bit'].bare_values.index(mod)
                encoded |= (val << src.lane)
            elif mod in enums['lanes_8_bit'].bare_values:
                die_if(not src.lanes, "Instruction doesn't take a lane")
                die_if(swizzled, "Multiple swizzles specified")
                swizzled = True
                val = enums['lanes_8_bit'].bare_values.index(mod)
                encoded |= (val << src.offset['widen'])
            elif mod in ['w0', 'w1']:
                # Check for special
                die_if(not needs_word_select, 'Unexpected word select')

                if mod == 'w1':
                    encoded_src |= 0x1

                needs_word_select = False
            else:
                die(f"Unknown modifier {mod}")

        # Encode the identity if a swizzle is required but not specified
        if src.swizzle and not swizzled and src.size == 16:
            mod = enums['swizzles_16_bit'].default
            val = enums['swizzles_16_bit'].bare_values.index(mod)
            encoded |= (val << src.offset['swizzle'])
        elif src.widen and not swizzled and src.size == 16:
            mod = enums['swizzles_16_bit'].default
            val = enums['swizzles_16_bit'].bare_values.index(mod)
            encoded |= (val << src.offset['widen'])

        encoded |= encoded_src << src.start
        fau.push(encoded_src)

    operands = operands[len(ins.srcs):]

    for op, imm in zip(operands, ins.immediates):
        if op[0] == '#':
            die_if(imm.name != 'constant', "Wrong syntax for immediate")
            parts = [imm.name, op[1:]]
        else:
            parts = op.split(':')
            die_if(len(parts) != 2, f"Wrong syntax for immediate, wrong number of colons in {op}")
            die_if(parts[0] != imm.name, f"Wrong immediate, expected {imm.name}, got {parts[0]}")

        if imm.signed:
            minimum = -(1 << (imm.size - 1))
            maximum = +(1 << (imm.size - 1)) - 1
        else:
            minimum = 0
            maximum = (1 << imm.size) - 1

        val = parse_int(parts[1], minimum, maximum)

        if val < 0:
            # Store negative values in two's complement within imm.size bits
            val = (1 << imm.size) + val

        encoded |= (val << imm.start)

    # Encode the operation itself
    encoded |= (ins.opcode << 48)
    encoded |= (ins.opcode2 << ins.secondary_shift)

    # Encode FAU page (page 0 contributes no bits)
    if fau.page:
        encoded |= (fau.page << 57)

    # Encode modifiers
    has_flow = False
    for mod in mods:
        if len(mod) == 0:
            continue

        if mod in enums['flow'].bare_values:
            die_if(has_flow, "Multiple flow control modifiers specified")
            has_flow = True
            encoded |= (enums['flow'].bare_values.index(mod) << 59)
        else:
            candidates = [c for c in ins.modifiers if mod in c.bare_values]

            die_if(len(candidates) == 0, f"Invalid modifier {mod} used")
            assert len(candidates) == 1  # No ambiguous modifiers
            modifier_enum = candidates[0]

            value = modifier_enum.bare_values.index(mod)

            die_if(modifier_enum.name in modifier_map, f"{modifier_enum.name} specified twice")
            modifier_map[modifier_enum.name] = value

    # Encode every modifier field, falling back to its default when unset.
    for mod in ins.modifiers:
        value = modifier_map.get(mod.name, mod.default)
        die_if(value is None, f"Missing required modifier {mod.name}")

        assert value < (1 << mod.size)
        encoded |= (value << mod.start)

    return encoded
373
if __name__ == "__main__":
    # Provide commandline interface
    parser = argparse.ArgumentParser(description='Assemble Valhall shaders')
    parser.add_argument('infile', nargs='?', type=argparse.FileType('r'),
                        default=sys.stdin)
    parser.add_argument('outfile', type=argparse.FileType('wb'))
    args = parser.parse_args()

    with args.infile as infile:
        source = infile.read()

    # One instruction per line. Strip each line so CRLF endings and
    # indentation are tolerated, then drop blank lines and '#' comments.
    lines = [ln.strip() for ln in source.splitlines()]
    lines = [ln for ln in lines if ln and not ln.startswith('#')]

    # Each instruction is emitted as a little-endian 64-bit word.
    packed = b''.join(struct.pack('<Q', parse_asm(ln)) for ln in lines)

    with args.outfile as outfile:
        outfile.write(packed)
387