xref: /aosp_15_r20/external/mesa3d/src/intel/vulkan/grl/grl_metakernel_gen.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1#!/bin/env python
2COPYRIGHT = """\
3/*
4 * Copyright 2021 Intel Corporation
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a
7 * copy of this software and associated documentation files (the
8 * "Software"), to deal in the Software without restriction, including
9 * without limitation the rights to use, copy, modify, merge, publish,
10 * distribute, sub license, and/or sell copies of the Software, and to
11 * permit persons to whom the Software is furnished to do so, subject to
12 * the following conditions:
13 *
14 * The above copyright notice and this permission notice (including the
15 * next paragraph) shall be included in all copies or substantial portions
16 * of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
21 * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
22 * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 */
26"""
27
28import argparse
29import os.path
30import re
31import sys
32
33from grl_parser import parse_grl_file
34
35class Writer(object):
36    def __init__(self, file):
37        self._file = file
38        self._indent = 0
39        self._new_line = True
40
41    def push_indent(self, levels=4):
42        self._indent += levels
43
44    def pop_indent(self, levels=4):
45        self._indent -= levels
46
47    def write(self, s, *fmt):
48        if self._new_line:
49            s = '\n' + s
50        self._new_line = False
51        if s.endswith('\n'):
52            self._new_line = True
53            s = s[:-1]
54        if fmt:
55            s = s.format(*fmt)
56        self._file.write(s.replace('\n', '\n' + ' ' * self._indent))
57
58# Internal Representation
59
60class Value(object):
61    def __init__(self, name=None, zone=None):
62        self.name = name
63        self._zone = zone
64        self.live = False
65
66    @property
67    def zone(self):
68        assert self._zone is not None
69        return self._zone
70
71    def is_reg(self):
72        return False
73
74    def c_val(self):
75        if not self.name:
76            print(self)
77        assert self.name
78        return self.name
79
80    def c_cpu_val(self):
81        assert self.zone == 'cpu'
82        return self.c_val()
83
84    def c_gpu_val(self):
85        if self.zone == 'gpu':
86            return self.c_val()
87        else:
88            return 'mi_imm({})'.format(self.c_cpu_val())
89
90class Constant(Value):
91    def __init__(self, value):
92        super().__init__(zone='cpu')
93        self.value = value
94
95    def c_val(self):
96        if self.value < 100:
97            return str(self.value)
98        elif self.value < (1 << 32):
99            return '0x{:x}u'.format(self.value)
100        else:
101            return '0x{:x}ull'.format(self.value)
102
103class Register(Value):
104    def __init__(self, name):
105        super().__init__(name=name, zone='gpu')
106
107    def is_reg(self):
108        return True
109
110class FixedGPR(Register):
111    def __init__(self, num):
112        super().__init__('REG{}'.format(num))
113        self.num = num
114
115    def write_c(self, w):
116        w.write('UNUSED struct mi_value {} = mi_reserve_gpr(&b, {});\n',
117                self.name, self.num)
118
119class GroupSizeRegister(Register):
120    def __init__(self, comp):
121        super().__init__('DISPATCHDIM_' + 'XYZ'[comp])
122        self.comp = comp
123
124class Member(Value):
125    def __init__(self, value, member):
126        super().__init__(zone=value.zone)
127        self.value = value
128        self.member = member
129
130    def is_reg(self):
131        return self.value.is_reg()
132
133    def c_val(self):
134        c_val = self.value.c_val()
135        if self.zone == 'gpu':
136            assert isinstance(self.value, Register)
137            if self.member == 'hi':
138                return 'mi_value_half({}, true)'.format(c_val)
139            elif self.member == 'lo':
140                return 'mi_value_half({}, false)'.format(c_val)
141            else:
142                assert False, 'Invalid member: {}'.format(self.member)
143        else:
144            return '.'.join([c_val, self.member])
145
146class OffsetOf(Value):
147    def __init__(self, mk, expr):
148        super().__init__(zone='cpu')
149        assert isinstance(expr, tuple) and expr[0] == 'member'
150        self.type = mk.m.get_type(expr[1])
151        self.field = expr[2]
152
153    def c_val(self):
154        return 'offsetof({}, {})'.format(self.type.c_name, self.field)
155
156class Scope(object):
157    def __init__(self, m, mk, parent):
158        self.m = m
159        self.mk = mk
160        self.parent = parent
161        self.defs = {}
162
163    def add_def(self, d, name=None):
164        if name is None:
165            name = d.name
166        assert name not in self.defs
167        self.defs[name] = d
168
169    def get_def(self, name):
170        if name in self.defs:
171            return self.defs[name]
172        assert self.parent, 'Unknown definition: "{}"'.format(name)
173        return self.parent.get_def(name)
174
175class Statement(object):
176    def __init__(self, srcs=[]):
177        assert isinstance(srcs, (list, tuple))
178        self.srcs = list(srcs)
179
180class SSAStatement(Statement, Value):
181    _count = 0
182
183    def __init__(self, zone, srcs):
184        Statement.__init__(self, srcs)
185        Value.__init__(self, None, zone)
186        self.c_name = '_tmp{}'.format(SSAStatement._count)
187        SSAStatement._count += 1
188
189    def c_val(self):
190        return self.c_name
191
192    def write_c_refs(self, w):
193        assert self.zone == 'gpu'
194        assert self.uses > 0
195        if self.uses > 1:
196            w.write('mi_value_add_refs(&b, {}, {});\n',
197                    self.c_name, self.uses - 1)
198
199class Half(SSAStatement):
200    def __init__(self, value, half):
201        assert half in ('hi', 'lo')
202        super().__init__(None, [value])
203        self.half = half
204
205    @property
206    def zone(self):
207        return self.srcs[0].zone
208
209    def write_c(self, w):
210        assert self.half in ('hi', 'lo')
211        if self.zone == 'cpu':
212            if self.half == 'hi':
213                w.write('uint32_t {} = (uint64_t)({}) >> 32;\n',
214                        self.c_name, self.srcs[0].c_cpu_val())
215            else:
216                w.write('uint32_t {} = {};\n',
217                        self.c_name, self.srcs[0].c_cpu_val())
218        else:
219            if self.half == 'hi':
220                w.write('struct mi_value {} = mi_value_half({}, true);\n',
221                        self.c_name, self.srcs[0].c_gpu_val())
222            else:
223                w.write('struct mi_value {} = mi_value_half({}, false);\n',
224                        self.c_name, self.srcs[0].c_gpu_val())
225            self.write_c_refs(w)
226
227class Expression(SSAStatement):
228    def __init__(self, mk, op, *srcs):
229        super().__init__(None, srcs)
230        self.op = op
231
232    @property
233    def zone(self):
234        zone = 'cpu'
235        for s in self.srcs:
236            if s.zone == 'gpu':
237                zone = 'gpu'
238        return zone
239
240    def write_c(self, w):
241        if self.zone == 'cpu':
242            c_cpu_vals = [s.c_cpu_val() for s in self.srcs]
243            # There is one bitfield that is a uint64_t, but only holds 2 bits.
244            # In practice we won't overflow, but let's help the compiler (and
245            # coverity) out here.
246            if self.op == '<<':
247                w.write(f'assume({c_cpu_vals[0]} < (1 << 8));')
248            w.write('uint64_t {} = ', self.c_name)
249            if len(self.srcs) == 1:
250                w.write('({} {})', self.op, c_cpu_vals[0])
251            elif len(self.srcs) == 2:
252                w.write('({} {} {})', c_cpu_vals[0], self.op, c_cpu_vals[1])
253            else:
254                assert len(self.srcs) == 3 and op == '?'
255                w.write('({} ? {} : {})', *c_cpu_vals)
256            w.write(';\n')
257            return
258
259        w.write('struct mi_value {} = ', self.c_name)
260        if self.op == '~':
261            w.write('mi_inot(&b, {});\n', self.srcs[0].c_gpu_val())
262        elif self.op == '+':
263            w.write('mi_iadd(&b, {}, {});\n',
264                    self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
265        elif self.op == '-':
266            w.write('mi_isub(&b, {}, {});\n',
267                    self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
268        elif self.op == '&':
269            w.write('mi_iand(&b, {}, {});\n',
270                    self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
271        elif self.op == '|':
272            w.write('mi_ior(&b, {}, {});\n',
273                    self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
274        elif self.op == '<<':
275            if self.srcs[1].zone == 'cpu':
276                w.write('mi_ishl_imm(&b, {}, {});\n',
277                        self.srcs[0].c_gpu_val(), self.srcs[1].c_cpu_val())
278            else:
279                w.write('mi_ishl(&b, {}, {});\n',
280                        self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
281        elif self.op == '>>':
282            if self.srcs[1].zone == 'cpu':
283                w.write('mi_ushr_imm(&b, {}, {});\n',
284                        self.srcs[0].c_gpu_val(), self.srcs[1].c_cpu_val())
285            else:
286                w.write('mi_ushr(&b, {}, {});\n',
287                        self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
288        elif self.op == '==':
289            w.write('mi_ieq(&b, {}, {});\n',
290                    self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
291        elif self.op == '<':
292            w.write('mi_ult(&b, {}, {});\n',
293                    self.srcs[0].c_gpu_val(), self.srcs[1].c_gpu_val())
294        elif self.op == '>':
295            w.write('mi_ult(&b, {}, {});\n',
296                    self.srcs[1].c_gpu_val(), self.srcs[0].c_gpu_val())
297        elif self.op == '<=':
298            w.write('mi_uge(&b, {}, {});\n',
299                    self.srcs[1].c_gpu_val(), self.srcs[0].c_gpu_val())
300        else:
301            assert False, 'Unknown expression opcode: {}'.format(self.op)
302        self.write_c_refs(w)
303
304class StoreReg(Statement):
305    def __init__(self, mk, reg, value):
306        super().__init__([mk.load_value(value)])
307        self.reg = mk.parse_value(reg)
308        assert self.reg.is_reg()
309
310    def write_c(self, w):
311        value = self.srcs[0]
312        w.write('mi_store(&b, {}, {});\n',
313                self.reg.c_gpu_val(), value.c_gpu_val())
314
315class LoadMem(SSAStatement):
316    def __init__(self, mk, bit_size, addr):
317        super().__init__('gpu', [mk.load_value(addr)])
318        self.bit_size = bit_size
319
320    def write_c(self, w):
321        addr = self.srcs[0]
322        w.write('struct mi_value {} = ', self.c_name)
323        if addr.zone == 'cpu':
324            w.write('mi_mem{}(anv_address_from_u64({}));\n',
325                    self.bit_size, addr.c_cpu_val())
326        else:
327            assert self.bit_size == 64
328            w.write('mi_load_mem64_offset(&b, anv_address_from_u64(0), {});\n',
329                    addr.c_gpu_val())
330        self.write_c_refs(w)
331
332class StoreMem(Statement):
333    def __init__(self, mk, bit_size, addr, src):
334        super().__init__([mk.load_value(addr), mk.load_value(src)])
335        self.bit_size = bit_size
336
337    def write_c(self, w):
338        addr, data = tuple(self.srcs)
339        if addr.zone == 'cpu':
340            w.write('mi_store(&b, mi_mem{}(anv_address_from_u64({})), {});\n',
341                    self.bit_size, addr.c_cpu_val(), data.c_gpu_val())
342        else:
343            assert self.bit_size == 64
344            w.write('mi_store_mem64_offset(&b, anv_address_from_u64(0), {}, {});\n',
345                    addr.c_gpu_val(), data.c_gpu_val())
346
347class GoTo(Statement):
348    def __init__(self, mk, target_id, cond=None, invert=False):
349        cond = [mk.load_value(cond)] if cond is not None else []
350        super().__init__(cond)
351        self.target_id = target_id
352        self.invert = invert
353        self.mk = mk
354
355    def write_c(self, w):
356        # Now that we've parsed the entire metakernel, we can look up the
357        # actual target from the id
358        target = self.mk.get_goto_target(self.target_id)
359
360        if self.srcs:
361            cond = self.srcs[0]
362            if self.invert:
363                w.write('mi_goto_if(&b, mi_inot(&b, {}), &{});\n', cond.c_gpu_val(), target.c_name)
364            else:
365                w.write('mi_goto_if(&b, {}, &{});\n', cond.c_gpu_val(), target.c_name)
366        else:
367            w.write('mi_goto(&b, &{});\n', target.c_name)
368
369class GoToTarget(Statement):
370    def __init__(self, mk, name):
371        super().__init__()
372        self.name = name
373        self.c_name = '_goto_target_' + name
374        self.goto_tokens = []
375
376        mk = mk.add_goto_target(self)
377
378    def write_decl(self, w):
379        w.write('struct mi_goto_target {} = MI_GOTO_TARGET_INIT;\n',
380                self.c_name)
381
382    def write_c(self, w):
383        w.write('mi_goto_target(&b, &{});\n', self.c_name)
384
385class Dispatch(Statement):
386    def __init__(self, mk, kernel, group_size, args, postsync):
387        if group_size is None:
388            srcs = [mk.scope.get_def('DISPATCHDIM_{}'.format(d)) for d in 'XYZ']
389        else:
390            srcs = [mk.load_value(s) for s in group_size]
391        srcs += [mk.load_value(a) for a in args]
392        super().__init__(srcs)
393        self.kernel = mk.m.kernels[kernel]
394        self.indirect = group_size is None
395        self.postsync = postsync
396
397    def write_c(self, w):
398        w.write('{\n')
399        w.push_indent()
400
401        group_size = self.srcs[:3]
402        args = self.srcs[3:]
403        if not self.indirect:
404            w.write('const uint32_t _group_size[3] = {{ {}, {}, {} }};\n',
405                    *[s.c_cpu_val() for s in group_size])
406            gs = '_group_size'
407        else:
408            gs = 'NULL'
409
410        w.write('const struct anv_kernel_arg _args[] = {\n')
411        w.push_indent()
412        for arg in args:
413            w.write('{{ .u64 = {} }},\n', arg.c_cpu_val())
414        w.pop_indent()
415        w.write('};\n')
416
417        w.write('genX(grl_dispatch)(cmd_buffer, {},\n', self.kernel.c_name)
418        w.write('                   {}, ARRAY_SIZE(_args), _args);\n', gs)
419        w.pop_indent()
420        w.write('}\n')
421
422class SemWait(Statement):
423    def __init__(self, scope, wait):
424        super().__init__()
425        self.wait = wait
426
427class Control(Statement):
428    def __init__(self, scope, wait):
429        super().__init__()
430        self.wait = wait
431
432    def write_c(self, w):
433        w.write('cmd_buffer->state.pending_pipe_bits |=\n')
434        w.write('    ANV_PIPE_CS_STALL_BIT |\n')
435        w.write('    ANV_PIPE_DATA_CACHE_FLUSH_BIT |\n')
436        w.write('    ANV_PIPE_UNTYPED_DATAPORT_CACHE_FLUSH_BIT;\n')
437        w.write('genX(cmd_buffer_apply_pipe_flushes)(cmd_buffer);\n')
438
439TYPE_REMAPS = {
440    'dword' : 'uint32_t',
441    'qword' : 'uint64_t',
442}
443
444class Module(object):
445    def __init__(self, grl_dir, elems):
446        assert isinstance(elems[0], tuple)
447        assert elems[0][0] == 'module-name'
448        self.grl_dir = grl_dir
449        self.name = elems[0][1]
450        self.kernels = {}
451        self.structs = {}
452        self.constants = []
453        self.metakernels = []
454        self.regs = {}
455
456        scope = Scope(self, None, None)
457        for e in elems[1:]:
458            if e[0] == 'kernel':
459                k = Kernel(self, *e[1:])
460                assert k.name not in self.kernels
461                self.kernels[k.name] = k
462            elif e[0] == 'kernel-module':
463                m = KernelModule(self, *e[1:])
464                for k in m.kernels:
465                    assert k.name not in self.kernels
466                    self.kernels[k.name] = k
467            elif e[0] == 'struct':
468                s = Struct(self, *e[1:])
469                assert s.name not in self.kernels
470                self.structs[s.name] = s
471            elif e[0] == 'named-constant':
472                c = NamedConstant(*e[1:])
473                scope.add_def(c)
474                self.constants.append(c)
475            elif e[0] == 'meta-kernel':
476                mk = MetaKernel(self, scope, *e[1:])
477                self.metakernels.append(mk)
478            elif e[0] == 'import':
479                assert e[2] == 'struct'
480                self.import_struct(e[1], e[3])
481            else:
482                assert False, 'Invalid module-level token: {}'.format(t[0])
483
484    def import_struct(self, filename, struct_name):
485        elems = parse_grl_file(os.path.join(self.grl_dir, filename), [])
486        assert elems
487        for e in elems[1:]:
488            if e[0] == 'struct' and e[1] == struct_name:
489                s = Struct(self, *e[1:])
490                assert s.name not in self.kernels
491                self.structs[s.name] = s
492                return
493        assert False, "Struct {0} not found in {1}".format(struct_name, filename)
494
495    def get_type(self, name):
496        if name in self.structs:
497            return self.structs[name]
498        return BasicType(TYPE_REMAPS.get(name, name))
499
500    def get_fixed_gpr(self, num):
501        assert isinstance(num, int)
502        if num in self.regs:
503            return self.regs[num]
504
505        reg = FixedGPR(num)
506        self.regs[num] = reg
507        return reg
508
509    def optimize(self):
510        progress = True
511        while progress:
512            progress = False
513
514            # Copy Propagation
515            for mk in self.metakernels:
516                if mk.opt_copy_prop():
517                    progress = True
518
519            # Dead Code Elimination
520            for r in self.regs.values():
521                r.live = False
522            for c in self.constants:
523                c.live = False
524            for mk in self.metakernels:
525                mk.opt_dead_code1()
526            for mk in self.metakernels:
527                if mk.opt_dead_code2():
528                    progress = True
529            for n in list(self.regs.keys()):
530                if not self.regs[n].live:
531                    del self.regs[n]
532                    progress = True
533            self.constants = [c for c in self.constants if c.live]
534
535    def compact_regs(self):
536        old_regs = self.regs
537        self.regs = {}
538        for i, reg in enumerate(old_regs.values()):
539            reg.num = i
540            self.regs[i] = reg
541
542    def write_h(self, w):
543        for s in self.structs.values():
544            s.write_h(w)
545        for mk in self.metakernels:
546            mk.write_h(w)
547
548    def write_c(self, w):
549        for c in self.constants:
550            c.write_c(w)
551        for mk in self.metakernels:
552            mk.write_c(w)
553
554class Kernel(object):
555    def __init__(self, m, name, ann):
556        self.name = name
557        self.source_file = ann['source']
558        self.kernel_name = self.source_file.replace('/', '_')[:-3].upper()
559        self.entrypoint = ann['kernelFunction']
560
561        assert self.source_file.endswith('.cl')
562        self.c_name = '_'.join([
563            'GRL_CL_KERNEL',
564            self.kernel_name,
565            self.entrypoint.upper(),
566        ])
567
568class KernelModule(object):
569    def __init__(self, m, name, source, kernels):
570        self.name = name
571        self.kernels = []
572        self.libraries = []
573
574        for k in kernels:
575            if k[0] == 'kernel':
576                k[2]['source'] = source
577                self.kernels.append(Kernel(m, *k[1:]))
578            elif k[0] == 'library':
579                # Skip this for now.
580                pass
581
582class BasicType(object):
583    def __init__(self, name):
584        self.name = name
585        self.c_name = name
586
587class Struct(object):
588    def __init__(self, m, name, fields, align):
589        assert align == 0
590        self.name = name
591        self.c_name = 'struct ' + '_'.join(['grl', m.name, self.name])
592        self.fields = [(m.get_type(t), n) for t, n in fields]
593
594    def write_h(self, w):
595        w.write('{} {{\n', self.c_name)
596        w.push_indent()
597        for f in self.fields:
598            w.write('{} {};\n', f[0].c_name, f[1])
599        w.pop_indent()
600        w.write('};\n')
601
602class NamedConstant(Value):
603    def __init__(self, name, value):
604        super().__init__(name, 'cpu')
605        self.name = name
606        self.value = Constant(value)
607        self.written = False
608
609    def set_module(self, m):
610        pass
611
612    def write_c(self, w):
613        if self.written:
614            return
615        w.write('static const uint64_t {} = {};\n',
616                self.name, self.value.c_val())
617        self.written = True
618
619class MetaKernelParameter(Value):
620    def __init__(self, mk, type, name):
621        super().__init__(name, 'cpu')
622        self.type = mk.m.get_type(type)
623
624class MetaKernel(object):
625    def __init__(self, m, m_scope, name, params, ann, statements):
626        self.m = m
627        self.name = name
628        self.c_name = '_'.join(['grl', m.name, self.name])
629        self.goto_targets = {}
630        self.num_tmps = 0
631
632        mk_scope = Scope(m, self, m_scope)
633
634        self.params = [MetaKernelParameter(self, *p) for p in params]
635        for p in self.params:
636            mk_scope.add_def(p)
637
638        mk_scope.add_def(GroupSizeRegister(0), name='DISPATCHDIM_X')
639        mk_scope.add_def(GroupSizeRegister(1), name='DISPATCHDIM_Y')
640        mk_scope.add_def(GroupSizeRegister(2), name='DISPATCHDIM_Z')
641
642        self.statements = []
643        self.parse_stmt(mk_scope, statements)
644        self.scope = None
645
646    def get_tmp(self):
647        tmpN = '_tmp{}'.format(self.num_tmps)
648        self.num_tmps += 1
649        return tmpN
650
651    def add_stmt(self, stmt):
652        self.statements.append(stmt)
653        return stmt
654
655    def parse_value(self, v):
656        if isinstance(v, Value):
657            return v
658        elif isinstance(v, str):
659            if re.match(r'REG\d+', v):
660                return self.m.get_fixed_gpr(int(v[3:]))
661            else:
662                return self.scope.get_def(v)
663        elif isinstance(v, int):
664            return Constant(v)
665        elif isinstance(v, tuple):
666            if v[0] == 'member':
667                return Member(self.parse_value(v[1]), v[2])
668            elif v[0] == 'offsetof':
669                return OffsetOf(self, v[1])
670            else:
671                op = v[0]
672                srcs = [self.parse_value(s) for s in v[1:]]
673                return self.add_stmt(Expression(self, op, *srcs))
674        else:
675            assert False, 'Invalid value: {}'.format(v[0])
676
677    def load_value(self, v):
678        v = self.parse_value(v)
679        if isinstance(v, Member) and v.zone == 'gpu':
680            v = self.add_stmt(Half(v.value, v.member))
681        return v
682
683    def parse_stmt(self, scope, s):
684        self.scope = scope
685        if isinstance(s, list):
686            subscope = Scope(self.m, self, scope)
687            for stmt in s:
688                self.parse_stmt(subscope, stmt)
689        elif s[0] == 'define':
690            scope.add_def(self.parse_value(s[2]), name=s[1])
691        elif s[0] == 'assign':
692            self.add_stmt(StoreReg(self, *s[1:]))
693        elif s[0] == 'dispatch':
694            self.add_stmt(Dispatch(self, *s[1:]))
695        elif s[0] == 'load-dword':
696            v = self.add_stmt(LoadMem(self, 32, s[2]))
697            self.add_stmt(StoreReg(self, s[1], v))
698        elif s[0] == 'load-qword':
699            v = self.add_stmt(LoadMem(self, 64, s[2]))
700            self.add_stmt(StoreReg(self, s[1], v))
701        elif s[0] == 'store-dword':
702            self.add_stmt(StoreMem(self, 32, *s[1:]))
703        elif s[0] == 'store-qword':
704            self.add_stmt(StoreMem(self, 64, *s[1:]))
705        elif s[0] == 'goto':
706            self.add_stmt(GoTo(self, s[1]))
707        elif s[0] == 'goto-if':
708            self.add_stmt(GoTo(self, s[1], s[2]))
709        elif s[0] == 'goto-if-not':
710            self.add_stmt(GoTo(self, s[1], s[2], invert=True))
711        elif s[0] == 'label':
712            self.add_stmt(GoToTarget(self, s[1]))
713        elif s[0] == 'control':
714            self.add_stmt(Control(self, s[1]))
715        elif s[0] == 'sem-wait-while':
716            self.add_stmt(Control(self, s[1]))
717        else:
718            assert False, 'Invalid statement: {}'.format(s[0])
719
720    def add_goto_target(self, t):
721        assert t.name not in self.goto_targets
722        self.goto_targets[t.name] = t
723
724    def get_goto_target(self, name):
725        return self.goto_targets[name]
726
727    def opt_copy_prop(self):
728        progress = False
729        copies = {}
730        for stmt in self.statements:
731            for i in range(len(stmt.srcs)):
732                src = stmt.srcs[i]
733                if isinstance(src, FixedGPR) and src.num in copies:
734                    stmt.srcs[i] = copies[src.num]
735                    progress = True
736
737            if isinstance(stmt, StoreReg):
738                reg = stmt.reg
739                if isinstance(reg, Member):
740                    reg = reg.value
741
742                if isinstance(reg, FixedGPR):
743                    copies.pop(reg.num, None)
744                    if not stmt.srcs[0].is_reg():
745                        copies[reg.num] = stmt.srcs[0]
746            elif isinstance(stmt, (GoTo, GoToTarget)):
747                copies = {}
748
749        return progress
750
751    def opt_dead_code1(self):
752        for stmt in self.statements:
753            # Mark every register which is read as live
754            for src in stmt.srcs:
755                if isinstance(src, Register):
756                    src.live = True
757
758            # Initialize every SSA statement to dead
759            if isinstance(stmt, SSAStatement):
760                stmt.live = False
761
762    def opt_dead_code2(self):
763        def yield_live(statements):
764            gprs_read = set(self.m.regs.keys())
765            for stmt in statements:
766                if isinstance(stmt, SSAStatement):
767                    if not stmt.live:
768                        continue
769                elif isinstance(stmt, StoreReg):
770                    reg = stmt.reg
771                    if isinstance(reg, Member):
772                        reg = reg.value
773
774                    if not stmt.reg.live:
775                        continue
776
777                    if isinstance(reg, FixedGPR):
778                        if reg.num in gprs_read:
779                            gprs_read.remove(reg.num)
780                        else:
781                            continue
782                elif isinstance(stmt, (GoTo, GoToTarget)):
783                    gprs_read = set(self.m.regs.keys())
784
785                for src in stmt.srcs:
786                    src.live = True
787                    if isinstance(src, FixedGPR):
788                        gprs_read.add(src.num)
789                yield stmt
790
791        old_stmt_list = self.statements
792        old_stmt_list.reverse()
793        self.statements = list(yield_live(old_stmt_list))
794        self.statements.reverse()
795        return len(self.statements) != len(old_stmt_list)
796
797    def count_ssa_value_uses(self):
798        for stmt in self.statements:
799            if isinstance(stmt, SSAStatement):
800                stmt.uses = 0
801
802            for src in stmt.srcs:
803                if isinstance(src, SSAStatement):
804                    src.uses += 1
805
806    def write_h(self, w):
807        w.write('void\n')
808        w.write('genX({})(\n', self.c_name)
809        w.push_indent()
810        w.write('struct anv_cmd_buffer *cmd_buffer')
811        for p in self.params:
812            w.write(',\n{} {}', p.type.c_name, p.name)
813        w.write(');\n')
814        w.pop_indent()
815
816    def write_c(self, w):
817        w.write('void\n')
818        w.write('genX({})(\n', self.c_name)
819        w.push_indent()
820        w.write('struct anv_cmd_buffer *cmd_buffer')
821        for p in self.params:
822            w.write(',\n{} {}', p.type.c_name, p.name)
823        w.write(')\n')
824        w.pop_indent()
825        w.write('{\n')
826        w.push_indent()
827
828        w.write('struct mi_builder b;\n')
829        w.write('mi_builder_init(&b, cmd_buffer->device->info, &cmd_buffer->batch);\n')
830        w.write('/* TODO: use anv_mocs? */\n');
831        w.write('const uint32_t mocs = isl_mocs(&cmd_buffer->device->isl_dev, 0, false);\n');
832        w.write('mi_builder_set_mocs(&b, mocs);\n');
833        w.write('\n')
834
835        for r in self.m.regs.values():
836            r.write_c(w)
837        w.write('\n')
838
839        for t in self.goto_targets.values():
840            t.write_decl(w)
841        w.write('\n')
842
843        self.count_ssa_value_uses()
844        for s in self.statements:
845            s.write_c(w)
846
847        w.pop_indent()
848
849        w.write('}\n')
850
851HEADER_PROLOGUE = COPYRIGHT + '''
852#include "anv_private.h"
853#include "grl/genX_grl.h"
854
855#ifndef {0}
856#define {0}
857
858#ifdef __cplusplus
859extern "C" {{
860#endif
861
862'''
863
864HEADER_EPILOGUE = '''
865#ifdef __cplusplus
866}}
867#endif
868
869#endif /* {0} */
870'''
871
872C_PROLOGUE = COPYRIGHT + '''
873#include "{0}"
874
875#include "genxml/gen_macros.h"
876#include "genxml/genX_pack.h"
877#include "genxml/genX_rt_pack.h"
878
879#include "genX_mi_builder.h"
880
881#define MI_PREDICATE_RESULT mi_reg32(0x2418)
882#define DISPATCHDIM_X mi_reg32(0x2500)
883#define DISPATCHDIM_Y mi_reg32(0x2504)
884#define DISPATCHDIM_Z mi_reg32(0x2508)
885'''
886
887def parse_libraries(filenames):
888    libraries = {}
889    for fname in filenames:
890        lib_package = parse_grl_file(fname, [])
891        for lib in lib_package:
892            assert lib[0] == 'library'
893            # Add the directory of the library so that CL files can be found.
894            lib[2].append(('path', os.path.dirname(fname)))
895            libraries[lib[1]] = lib
896    return libraries
897
898def main():
899    argparser = argparse.ArgumentParser()
900    argparser.add_argument('--out-c', help='Output C file')
901    argparser.add_argument('--out-h', help='Output C file')
902    argparser.add_argument('--library', dest='libraries', action='append',
903                           default=[], help='Libraries to include')
904    argparser.add_argument('grl', help="Input  file")
905    args = argparser.parse_args()
906
907    grl_dir = os.path.dirname(args.grl)
908
909    libraries = parse_libraries(args.libraries)
910
911    ir = parse_grl_file(args.grl, libraries)
912
913    m = Module(grl_dir, ir)
914    m.optimize()
915    m.compact_regs()
916
917    with open(args.out_h, 'w') as f:
918        guard = os.path.splitext(os.path.basename(args.out_h))[0].upper()
919        w = Writer(f)
920        w.write(HEADER_PROLOGUE, guard)
921        m.write_h(w)
922        w.write(HEADER_EPILOGUE, guard)
923
924    with open(args.out_c, 'w') as f:
925        w = Writer(f)
926        w.write(C_PROLOGUE, os.path.basename(args.out_h))
927        m.write_c(w)
928
929if __name__ == '__main__':
930    main()
931