xref: /aosp_15_r20/external/mesa3d/src/imagination/csbgen/gen_pack_header.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# encoding=utf-8
2
3# Copyright © 2022 Imagination Technologies Ltd.
4
5# based on anv driver gen_pack_header.py which is:
6# Copyright © 2016 Intel Corporation
7
8# based on v3dv driver gen_pack_header.py which is:
9# Copyright (C) 2016 Broadcom
10
11# Permission is hereby granted, free of charge, to any person obtaining a copy
12# of this software and associated documentation files (the "Software"), to deal
13# in the Software without restriction, including without limitation the rights
14# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15# copies of the Software, and to permit persons to whom the Software is
16# furnished to do so, subject to the following conditions:
17
18# The above copyright notice and this permission notice (including the next
19# paragraph) shall be included in all copies or substantial portions of the
20# Software.
21
22# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28# SOFTWARE.
29
30from __future__ import annotations
31
32import copy
33import os
34import textwrap
35import typing as t
36import xml.parsers.expat as expat
37from abc import ABC
38from ast import literal_eval
39
40
# C block comment carrying the MIT license text that opens the generated
# header. The "%(copyright)s" placeholder is filled in through %-formatting
# with a mapping (see Csbgen.emit()).
MIT_LICENSE_COMMENT = """/*
 * Copyright © %(copyright)s
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */"""
63
# Preamble of the generated header: license text, a "generated file" notice,
# the opening of the include guard and the packet helper include. It is
# %-formatted with a mapping providing "license", "platform" and "guard"
# (see Csbgen.emit(), which also prints the matching #endif).
PACK_FILE_HEADER = """%(license)s

/* Enums, structures and pack functions for %(platform)s.
 *
 * This file has been generated, do not hand edit.
 */

#ifndef %(guard)s
#define %(guard)s

#include "csbgen/pvr_packet_helpers.h"

"""
77
78
def safe_name(name: str) -> str:
    """Return *name*, prefixed with "_" unless its first character is a letter."""
    return name if name[0].isalpha() else "_" + name
84
85
def num_from_str(num_str: str) -> int:
    """Parse a decimal or "0x"-prefixed hexadecimal string into an int.

    Raises:
        ValueError: if a non-hex string longer than one character starts
            with "0" (octal-looking input is rejected), or if the text is
            not a valid number.
    """
    is_hex = num_str.lower().startswith("0x")

    if not is_hex and num_str.startswith("0") and len(num_str) > 1:
        raise ValueError("Octal numbers not allowed")

    return int(num_str, base=16) if is_hex else int(num_str)
94
95
class Node(ABC):
    """Base class for every element of the tree built from the XML input.

    A node knows its parent and its (sanitised) name. Names and prefixes used
    in the generated C code are derived by walking up through ``parent``.
    """

    __slots__ = ["parent", "name"]

    parent: Node
    name: str

    def __init__(self, parent: Node, name: str, *, name_is_safe: bool = False) -> None:
        """Store the parent and the name, sanitising the name unless told not to."""
        self.parent = parent
        self.name = name if name_is_safe else safe_name(name)

    @property
    def full_name(self) -> str:
        """Upper-cased name prepended with the parent's prefix."""
        # A name that already starts with "_" supplies its own separator.
        separator = "" if self.name[0] == "_" else "_"
        return self.parent.prefix + separator + self.name.upper()

    @property
    def prefix(self) -> str:
        """Prefix handed down to children; by default simply the parent's prefix."""
        return self.parent.prefix

    def add(self, element: Node) -> None:
        """Reject nesting; subclasses override this for the children they accept.

        Raises:
            RuntimeError: always.
        """
        container = type(self).__name__.lower()
        nested = type(element).__name__
        raise RuntimeError("Element cannot be nested in %s. Element Type: %s"
                           % (container, nested))
123
124
class Csbgen(Node):
    """Root node: owns all top-level defines, enums, structs and streams.

    Its ``full_name`` (name + "_" + prefix field) is the prefix inherited by
    everything nested below it.
    """

    __slots__ = ["prefix_field", "filename", "_defines", "_enums", "_structs", "_streams"]

    prefix_field: str
    filename: str
    _defines: t.List[Define]
    _enums: t.Dict[str, Enum]
    _structs: t.Dict[str, Struct]
    _streams: t.Dict[str, Stream]

    def __init__(self, name: str, prefix: str, filename: str) -> None:
        """Create the root; it has no parent and upper-cases name and prefix."""
        super().__init__(None, name.upper())
        self.filename = filename
        self.prefix_field = safe_name(prefix.upper())

        self._defines = []
        self._enums, self._structs, self._streams = {}, {}, {}

    @property
    def full_name(self) -> str:
        return "_".join((self.name, self.prefix_field))

    @property
    def prefix(self) -> str:
        return self.full_name

    def add(self, element: Node) -> None:
        """Register a top-level element, refusing duplicates.

        Raises:
            RuntimeError: if an element of the same kind and name already
                exists, or if the element type cannot live at the top level.
        """
        # Enums, structs and streams are all stored by name in their own dict.
        keyed_registries = (
            (Enum, self._enums, "Enum redefined. Enum: %s"),
            (Struct, self._structs, "Struct redefined. Struct: %s"),
            (Stream, self._streams, "Stream redefined. Stream: %s"),
        )
        for kind, registry, duplicate_message in keyed_registries:
            if not isinstance(element, kind):
                continue

            if element.name in registry:
                raise RuntimeError(duplicate_message % element.name)

            registry[element.name] = element
            return

        if not isinstance(element, Define):
            super().add(element)
            return

        # Defines are kept in insertion order and compared by full name.
        taken_names = {existing.full_name for existing in self._defines}
        if element.full_name in taken_names:
            raise RuntimeError("Define redefined. Define: %s" % element.full_name)

        self._defines.append(element)

    def _gen_guard(self) -> str:
        """Derive the include guard from the file name, e.g. foo.xml -> FOO_H."""
        base = os.path.basename(self.filename)
        return base.replace(".xml", "_h").upper()

    def emit(self) -> None:
        """Print the whole generated header to stdout."""
        guard = self._gen_guard()
        license_text = MIT_LICENSE_COMMENT % {"copyright": "2022 Imagination Technologies Ltd."}

        print(PACK_FILE_HEADER % {
            "license": license_text,
            "platform": self.name,
            "guard": guard,
        })

        for define in self._defines:
            define.emit()

        print()

        for enum in self._enums.values():
            enum.emit()

        # Structs and streams need the root to resolve field types.
        for struct in self._structs.values():
            struct.emit(self)

        for stream in self._streams.values():
            stream.emit(self)

        print("#endif /* %s */" % guard)

    def is_known_struct(self, struct_name: str) -> bool:
        return struct_name in self._structs

    def is_known_enum(self, enum_name: str) -> bool:
        return enum_name in self._enums

    def get_enum(self, enum_name: str) -> Enum:
        """Return the enum registered under *enum_name* (KeyError if unknown)."""
        return self._enums[enum_name]

    def get_struct(self, struct_name: str) -> Struct:
        """Return the struct registered under *struct_name* (KeyError if unknown)."""
        return self._structs[struct_name]
215
216
class Enum(Node):
    """A C enum made of named Value children.

    Emits the ``enum`` definition followed by a ``<enum>_to_str()`` helper.
    """

    __slots__ = ["_values"]

    _values: t.Dict[str, Value]

    def __init__(self, parent: Node, name: str) -> None:
        """Create the enum and register it with *parent*."""
        super().__init__(parent, name)

        self._values = {}

        self.parent.add(self)

    # We override prefix so that the values will contain the enum's name too.
    @property
    def prefix(self) -> str:
        return self.full_name

    def get_value(self, value_name: str) -> Value:
        """Return the Value called *value_name* (KeyError if unknown)."""
        return self._values[value_name]

    def add(self, element: Node) -> None:
        """Add a Value to the enum.

        Raises:
            RuntimeError: if *element* is not a Value, if its name is already
                used, or if another value already has the same number.
        """
        if not isinstance(element, Value):
            # Node.add() raises for unsupported element types.
            super().add(element)
            return

        if element.name in self._values:
            raise RuntimeError("Value is being redefined. Value: '%s'" % element.name)

        # Compare the numeric values: the dict stores Value objects, so a plain
        # `element.value in self._values.values()` can never match an int.
        # Duplicates must be rejected because the generated to_str() switch
        # has one case label per value.
        if any(known.value == element.value for known in self._values.values()):
            raise RuntimeError("Ambiguous enum value detected. Value: '%s'" % element.value)

        self._values[element.name] = element

    def _emit_to_str(self) -> None:
        """Print the C helper mapping each enum value to its name string."""
        print(textwrap.dedent("""\
            static const char *
            %s_to_str(const enum %s value)
            {""") % (self.full_name, self.full_name))

        print("    switch (value) {")
        for value in self._values.values():
            print("    case %s: return \"%s\";" % (value.full_name, value.name))
        print("    default: return NULL;")
        print("    }")

        print("}\n")

    def emit(self) -> None:
        """Print the enum definition and its to_str helper.

        Raises:
            RuntimeError: if the enum has no values.
        """
        # This check is invalid if tags other than Value can be nested within an enum.
        if not self._values:
            raise RuntimeError("Enum definition is empty. Enum: '%s'" % self.full_name)

        print("enum %s {" % self.full_name)
        for value in self._values.values():
            value.emit()
        print("};\n")

        self._emit_to_str()
274
275
class Value(Node):
    """One named constant belonging to an Enum."""

    __slots__ = ["value"]

    value: int

    def __init__(self, parent: Node, name: str, value: int) -> None:
        """Store the number, then register with the parent."""
        super().__init__(parent, name)

        # The value must be set before registering: the parent inspects it.
        self.value = value

        self.parent.add(self)

    def emit(self) -> None:
        """Print the enumerator line of the C enum."""
        enumerator = "    %-36s = %6d," % (self.full_name, self.value)
        print(enumerator)
290
291
class Struct(Node):
    """A packed structure of ``length`` dwords holding fields and conditions.

    Emits the length define, the ``_header`` macro with default values, the
    helper macros of its fields, the C struct, and its pack/unpack functions.
    """

    __slots__ = ["length", "size", "_children"]

    length: int
    size: int
    _children: t.Dict[str, t.Union[Condition, Field]]

    def __init__(self, parent: Node, name: str, length: int) -> None:
        """Create the struct and register it with *parent*.

        Raises:
            ValueError: if *length* (in dwords) is not positive.
        """
        super().__init__(parent, name)

        self.length = length
        # Size in bits: one dword is 32 bits.
        self.size = self.length * 32

        if self.length <= 0:
            raise ValueError("Struct length must be greater than 0. Struct: '%s'." % self.full_name)

        self._children = {}

        self.parent.add(self)

    @property
    def fields(self) -> t.List[Field]:
        """All fields of the struct, including those nested in conditions."""
        # TODO: Should we cache? See TODO in equivalent Condition getter.

        collected = []
        for child in self._children.values():
            if isinstance(child, Condition):
                collected.extend(child.fields)
            else:
                collected.append(child)

        return collected

    @property
    def prefix(self) -> str:
        return self.full_name

    def add(self, element: Node) -> None:
        """Add a field or a condition to the struct.

        Raises:
            ValueError: if a field name is already taken.
            RuntimeError: if a non-"if" condition has no matching "if", or
                the element type is not supported.
        """
        # We don't support conditions and field having the same name.
        if isinstance(element, Field):
            if element.name in self._children:
                raise ValueError("Field is being redefined. Field: '%s', Struct: '%s'"
                                 % (element.name, self.full_name))

            self._children[element.name] = element
            return

        if not isinstance(element, Condition):
            super().add(element)
            return

        # We only save ifs, and ignore the rest. The rest will be linked to
        # the if condition so we just need to call emit() on the if and the
        # rest will also be emitted.
        if element.type == "if":
            self._children[element.name] = element
        elif element.name not in self._children:
            raise RuntimeError("Unknown condition: '%s'" % element.name)

    def _emit_header(self, root: Csbgen) -> None:
        """Print the ``<struct>_header`` macro listing the default initialisers."""
        initialisers = []
        for field in self.fields:
            if field.default is None:
                continue

            if field.is_builtin_type:
                initialisers.append("    .%-35s = %6d" % (field.name, field.default))
                continue

            if not root.is_known_enum(field.type):
                # Default values should not apply to structures
                raise RuntimeError(
                    "Unknown type. Field: '%s' Type: '%s'"
                    % (field.name, field.type)
                )

            enum = root.get_enum(field.type)

            try:
                value = enum.get_value(field.default)
            except KeyError:
                raise ValueError("Unknown enum value. Value: '%s', Enum: '%s', Field: '%s'"
                                 % (field.default, enum.full_name, field.name))

            initialisers.append("    .%-35s = %s" % (field.name, value.full_name))

        macro_name = self.full_name + "_header"
        print("#define %-40s\\" % macro_name)
        print(",  \\\n".join(initialisers))
        print("")

    def _emit_helper_macros(self) -> None:
        """Print the defines attached to each field that has any."""
        for field in self.fields:
            if not field.defines:
                continue

            print("/* Helper macros for %s */" % field.name)

            for define in field.defines:
                define.emit()

            print()

    def _emit_pack_function(self, root: Csbgen) -> None:
        """Print the C ``<struct>_pack()`` function."""
        name = self.full_name
        # Aligns the second parameter under the first one.
        padding = ' ' * len(name)

        prototype = textwrap.dedent("""\
            static inline __attribute__((always_inline)) void
            %s_pack(__attribute__((unused)) void * restrict dst,
                  %s__attribute__((unused)) const struct %s * restrict values)
            {""")
        print(prototype % (name, padding, name))

        group = Group(0, 1, self.size, self.fields)
        dwords, length = group.collect_dwords_and_length()
        if length:
            # Cast dst to make header C++ friendly
            print("    uint32_t * restrict dw = (uint32_t * restrict) dst;")

        group.emit_pack_function(root, dwords, length)

        print("}\n")

    def _emit_unpack_function(self, root: Csbgen) -> None:
        """Print the C ``<struct>_unpack()`` function."""
        name = self.full_name
        # Aligns the second parameter under the first one.
        padding = ' ' * len(name)

        prototype = textwrap.dedent("""\
            static inline __attribute__((always_inline)) void
            %s_unpack(__attribute__((unused)) const void * restrict src,
                    %s__attribute__((unused)) struct %s * restrict values)
            {""")
        print(prototype % (name, padding, name))

        group = Group(0, 1, self.size, self.fields)
        dwords, length = group.collect_dwords_and_length()
        if length:
            # Cast src to make header C++ friendly
            print("    const uint32_t * restrict dw = (const uint32_t * restrict) src;")

        group.emit_unpack_function(root, dwords, length)

        print("}\n")

    def emit(self, root: Csbgen) -> None:
        """Print everything generated for this struct, in order."""
        length_macro = self.full_name + "_length"
        print("#define %-33s %6d" % (length_macro, self.length))

        self._emit_header(root)

        self._emit_helper_macros()

        print("struct %s {" % self.full_name)
        for member in self._children.values():
            member.emit(root)
        print("};\n")

        self._emit_pack_function(root)
        self._emit_unpack_function(root)
435
436
class Stream(Node):
    """Named container of fields and conditions.

    Emission is not implemented: every ``emit*`` method below is a no-op.
    """

    __slots__ = ["length", "size", "_children"]

    length: int
    size: int
    _children: t.Dict[str, t.Union[Condition, Field]]

    def __init__(self, parent: Node, name: str, length: int) -> None:
        """Create the stream node.

        NOTE(review): ``length`` is accepted but never stored, so the
        ``length`` and ``size`` slots stay unset, and the stream does not
        register itself with ``parent`` here. Confirm this is intended while
        emission is still stubbed out.
        """
        self._children = {}

        super().__init__(parent, name)

    @property
    def fields(self) -> t.List[Field]:
        """All fields of the stream, including those nested in conditions."""
        fields = []
        for child in self._children.values():
            if isinstance(child, Condition):
                fields += child.fields
            else:
                fields.append(child)

        return fields

    @property
    def prefix(self) -> str:
        return self.full_name

    def add(self, element: Node) -> None:
        """Add a field or a condition to the stream.

        Raises:
            ValueError: if a field name is already taken.
            RuntimeError: if a non-"if" condition has no matching "if", or
                the element type is not supported.
        """
        # We don't support conditions and field having the same name.
        if isinstance(element, Field):
            if element.name in self._children:
                raise ValueError("Field is being redefined. Field: '%s', Stream: '%s'"
                                 % (element.name, self.full_name))

            self._children[element.name] = element

        elif isinstance(element, Condition):
            # We only save ifs, and ignore the rest. The rest will be linked to
            # the if condition so we just need to call emit() on the if and the
            # rest will also be emitted.
            if element.type == "if":
                self._children[element.name] = element
            elif element.name not in self._children:
                raise RuntimeError("Unknown condition: '%s'" % element.name)

        else:
            super().add(element)

    def _emit_header(self, root: Csbgen) -> None:
        pass

    def _emit_helper_macros(self) -> None:
        pass

    def _emit_pack_function(self, root: Csbgen) -> None:
        pass

    def _emit_unpack_function(self, root: Csbgen) -> None:
        pass

    def emit(self, root: Csbgen) -> None:
        """Emit nothing: stream output is not implemented."""
        pass
493
class Field(Node):
    """A bit range of a struct/stream, emitted as one member of the C struct.

    ``start`` and ``end`` are inclusive bit positions, so a field is
    ``end - start + 1`` bits wide (a bool must have ``start == end``).
    """

    __slots__ = ["start", "end", "type", "default", "shift", "_defines"]

    start: int
    end: int
    type: str
    default: t.Optional[t.Union[str, int]]
    shift: t.Optional[int]
    _defines: t.Dict[str, Define]

    def __init__(self, parent: Node, name: str, start: int, end: int, ty: str, *,
                 default: t.Optional[str] = None, shift: t.Optional[int] = None) -> None:
        """Create the field, register it with *parent* and validate it.

        Args:
            parent: node the field is nested in.
            name: field name (sanitised by Node).
            start: first bit of the field (inclusive).
            end: last bit of the field (inclusive).
            ty: builtin type name, or the name of an enum/struct.
            default: default value as text; parsed as a number for builtin
                types, otherwise kept as a (sanitised) enum value name.
            shift: alignment shift; required for, and only allowed on,
                "address" fields.

        Raises:
            ValueError: if start > end, a bool is wider than one bit, or the
                default of a builtin type is not a valid number.
            RuntimeError: if shift is used on a non-address field or is
                missing on an address field.
        """
        super().__init__(parent, name)

        self.start = start
        self.end = end
        self.type = ty

        self._defines = {}

        self.parent.add(self)

        if self.start > self.end:
            raise ValueError("Start cannot be after end. Start: %d, End: %d, Field: '%s'"
                             % (self.start, self.end, self.name))

        if self.type == "bool" and self.end != self.start:
            raise ValueError("Bool field can only be 1 bit long. Field '%s'" % self.name)

        if default is not None:
            if not self.is_builtin_type:
                # Assuming it's an enum type.
                self.default = safe_name(default)
            else:
                self.default = num_from_str(default)
        else:
            self.default = None

        if shift is not None:
            if self.type != "address":
                raise RuntimeError("Only address fields can have a shift attribute. Field: '%s'" % self.name)

            self.shift = int(shift)

            # The Define registers itself on this field through add().
            Define(self, "ALIGNMENT", 2**self.shift)
        else:
            if self.type == "address":
                raise RuntimeError("Field of address type requires a shift attribute. Field '%s'" % self.name)

            self.shift = None

    @property
    def defines(self) -> t.Iterator[Define]:
        """Defines nested in this field."""
        return self._defines.values()

    # We override prefix so that the defines will contain the field's name too.
    @property
    def prefix(self) -> str:
        return self.full_name

    @property
    def is_builtin_type(self) -> bool:
        builtins = {"address", "bool", "float", "mbo", "offset", "int", "uint"}
        return self.type in builtins

    def _get_c_type(self, root: Csbgen) -> str:
        """Return the C type used for this field's struct member.

        Raises:
            RuntimeError: if the type is unknown or a uint is wider than 64 bits.
        """
        if self.type == "address":
            return "__pvr_address_type"
        elif self.type == "bool":
            return "bool"
        elif self.type == "float":
            return "float"
        elif self.type == "offset":
            return "uint64_t"
        elif self.type == "int":
            return "int32_t"
        elif self.type == "uint":
            # start and end are both inclusive, hence the + 1. Without it a
            # 33 bit field would be given a 32 bit type.
            width = self.end - self.start + 1
            if width <= 32:
                return "uint32_t"
            elif width <= 64:
                return "uint64_t"

            raise RuntimeError("No known C type found to hold %d bit sized value. Field: '%s'"
                               % (width, self.name))
        elif self.type == "uint_array":
            return "uint8_t"
        elif root.is_known_struct(self.type):
            return "struct " + self.type
        elif root.is_known_enum(self.type):
            return "enum " + root.get_enum(self.type).full_name
        raise RuntimeError("Unknown type. Type: '%s', Field: '%s'" % (self.type, self.name))

    def add(self, element: Node) -> None:
        """Nest a Define in the field.

        Raises:
            RuntimeError: if the field is an mbo field, the define name is
                already used, or the element is not a Define.
        """
        if self.type == "mbo":
            raise RuntimeError("No element can be nested in an mbo field. Element Type: %s, Field: %s"
                               % (type(element).__name__, self.name))

        if isinstance(element, Define):
            if element.name in self._defines:
                raise RuntimeError("Duplicate define. Define: '%s'" % element.name)

            self._defines[element.name] = element
        else:
            super().add(element)

    def emit(self, root: Csbgen) -> None:
        """Print the C struct member for this field; mbo fields have none."""
        if self.type == "mbo":
            return

        if self.type == "uint_array":
            # Whole bytes covered by the inclusive bit range [start, end].
            byte_count = (self.end - self.start + 1) // 8
            print("    %-36s %s[%u];" % (self._get_c_type(root), self.name, byte_count))
        else:
            print("    %-36s %s;" % (self._get_c_type(root), self.name))
606            print("    %-36s %s;" % (self._get_c_type(root), self.name))
607
608
class Define(Node):
    """A numeric C preprocessor define."""

    __slots__ = ["value"]

    value: int

    def __init__(self, parent: Node, name: str, value: int) -> None:
        """Store the value, then register with the parent."""
        super().__init__(parent, name)

        self.value = value

        self.parent.add(self)

    def emit(self) -> None:
        """Print the ``#define`` line."""
        line = "#define %-40s %d" % (self.full_name, self.value)
        print(line)
623
624
class Condition(Node):
    """One branch ("if", "elif", "else" or "endif") of a conditional block.

    An "if" owns the fields and nested conditions of its branch and links to
    the next branch through ``_child_branch``; the chain ends at the "endif".
    """

    __slots__ = ["type", "_children", "_child_branch"]

    type: str
    _children: t.Dict[str, t.Union[Condition, Field]]
    _child_branch: t.Optional[Condition]

    def __init__(self, parent: Node, name: str, ty: str) -> None:
        """Create the branch and register it with *parent*.

        Raises:
            RuntimeError: if *ty* is not a valid branch type.
        """
        super().__init__(parent, name, name_is_safe=True)

        self.type = ty
        if not Condition._is_valid_type(self.type):
            # Report the offending type, not the check name.
            raise RuntimeError("Unknown type: '%s'" % self.type)

        self._children = {}

        # This is the link to the next branch for the if statement so either
        # elif, else, or endif. They themselves will also have a link to the
        # next branch up until endif which terminates the chain.
        self._child_branch = None

        self.parent.add(self)

    @property
    def fields(self) -> t.List[Field]:
        """Fields of this branch, of nested conditions and of later branches."""
        # TODO: Should we use some kind of state to indicate that all of the
        # child nodes have been added and then cache the fields in here on the
        # first call so that we don't have to traverse them again per each call?
        # The state could be changed either when we reach the endif and pop from
        # the context, or when we start emitting.

        fields = []

        for child in self._children.values():
            if isinstance(child, Condition):
                fields += child.fields
            else:
                fields.append(child)

        if self._child_branch is not None:
            fields += self._child_branch.fields

        return fields

    @staticmethod
    def _is_valid_type(ty: str) -> bool:
        types = {"if", "elif", "else", "endif"}
        return ty in types

    def _is_compatible_child_branch(self, branch):
        """Tell whether *branch* may directly follow this branch."""
        # Branch types in the order they may appear; only an elif may be
        # followed by a branch of its own type.
        types = ["if", "elif", "else", "endif"]
        idx = types.index(self.type)
        return (branch.type in types[idx + 1:] or
                self.type == "elif" and branch.type == "elif")

    def _add_branch(self, branch: Condition) -> None:
        """Link *branch* as the branch following this one.

        Raises:
            RuntimeError: if an elif repeats the previous check, or the
                branch type cannot follow this branch's type.
        """
        if branch.type == "elif" and branch.name == self.name:
            raise RuntimeError("Elif branch cannot have same check as previous branch. Check: '%s'" % branch.name)

        if not self._is_compatible_child_branch(branch):
            raise RuntimeError("Invalid branch. Check: '%s', Type: '%s'" % (branch.name, branch.type))

        self._child_branch = branch

    # Returns the name of the if condition. This is used for elif branches since
    # they have a different name than the if condition thus we have to traverse
    # the chain of branches.
    # This is used to discriminate nested if conditions from branches since
    # branches like 'endif' and 'else' will have the same name as the 'if' (the
    # elif is an exception) while nested conditions will have different names.
    #
    # TODO: Redo this to improve speed? Would caching this be helpful? We could
    # just save the name of the if instead of having to walk towards it whenever
    # a new condition is being added.
    def _top_branch_name(self) -> str:
        if self.type == "if":
            return self.name

        # If we're not an 'if' condition, our parent must be another condition.
        assert isinstance(self.parent, Condition)
        return self.parent._top_branch_name()

    def add(self, element: Node) -> None:
        """Add a field, a following branch, or a nested "if" to this branch.

        Raises:
            ValueError: if a field name is already used in this branch.
            RuntimeError: if the condition is an invalid branch or an invalid
                nested "if", or the element type is not supported.
        """
        if isinstance(element, Field):
            if element.name in self._children:
                raise ValueError("Duplicate field. Field: '%s'" % element.name)

            self._children[element.name] = element
        elif isinstance(element, Condition):
            if element.type == "elif" or self._top_branch_name() == element.name:
                self._add_branch(element)
            else:
                if element.type != "if":
                    raise RuntimeError("Branch of an unopened if condition. Check: '%s', Type: '%s'."
                                       % (element.name, element.type))

                # This is a nested condition and we made sure that the name
                # doesn't match _top_branch_name() so we can recognize the else
                # and endif.
                # We recognized the elif by its type however its name differs
                # from the if condition thus when we add an if condition with
                # the same name as the elif nested in it, the _top_branch_name()
                # check doesn't hold true as the name matched the elif and not
                # the if statement which the elif was a branch of, thus the
                # nested if condition is not recognized as an invalid branch of
                # the outer if statement.
                #   Sample:
                #   <condition type="if" check="ROGUEXE"/>
                #       <condition type="elif" check="COMPUTE"/>
                #           <condition type="if" check="COMPUTE"/>
                #           <condition type="endif" check="COMPUTE"/>
                #       <condition type="endif" check="COMPUTE"/>
                #   <condition type="endif" check="ROGUEXE"/>
                #
                # We fix this by checking the if condition name against its
                # parent.
                if element.name == self.name:
                    raise RuntimeError("Invalid if condition. Check: '%s'" % element.name)

                self._children[element.name] = element
        else:
            super().add(element)

    def emit(self, root: Csbgen) -> None:
        """Print this branch as a C comment, then its children and next branch.

        Raises:
            RuntimeError: if the branch type is unknown, or the chain is not
                terminated by an endif.
        """
        if self.type == "if":
            print("/* if %s is supported use: */" % self.name)
        elif self.type == "elif":
            print("/* else if %s is supported use: */" % self.name)
        elif self.type == "else":
            print("/* else %s is not-supported use: */" % self.name)
        elif self.type == "endif":
            print("/* endif %s */" % self.name)
            return
        else:
            raise RuntimeError("Unknown condition type. Implementation error.")

        for child in self._children.values():
            child.emit(root)

        # Every branch other than endif needs a following branch. Fail with a
        # clear message rather than an AttributeError on None.
        if self._child_branch is None:
            raise RuntimeError("Unterminated condition. Check: '%s', Type: '%s'"
                               % (self.name, self.type))

        self._child_branch.emit(root)
765
766
767class Group:
768    __slots__ = ["start", "count", "size", "fields"]
769
770    start: int
771    count: int
772    size: int
773    fields: t.List[Field]
774
775    def __init__(self, start: int, count: int, size: int, fields) -> None:
776        self.start = start
777        self.count = count
778        self.size = size
779        self.fields = fields
780
781    class DWord:
782        __slots__ = ["size", "fields", "addresses"]
783
784        size: int
785        fields: t.List[Field]
786        addresses: t.List[Field]
787
788        def __init__(self) -> None:
789            self.size = 32
790            self.fields = []
791            self.addresses = []
792
793    def collect_dwords(self, dwords: t.Dict[int, Group.DWord], start: int) -> None:
794        for field in self.fields:
795            index = (start + field.start) // 32
796            if index not in dwords:
797                dwords[index] = self.DWord()
798
799            clone = copy.copy(field)
800            clone.start = clone.start + start
801            clone.end = clone.end + start
802            dwords[index].fields.append(clone)
803
804            if field.type == "address":
805                # assert dwords[index].address == None
806                dwords[index].addresses.append(clone)
807
808            # Coalesce all the dwords covered by this field. The two cases we
809            # handle are where multiple fields are in a 64 bit word (typically
810            # and address and a few bits) or where a single struct field
811            # completely covers multiple dwords.
812            while index < (start + field.end) // 32:
813                if index + 1 in dwords and not dwords[index] == dwords[index + 1]:
814                    dwords[index].fields.extend(dwords[index + 1].fields)
815                    dwords[index].addresses.extend(dwords[index + 1].addresses)
816                dwords[index].size = 64
817                dwords[index + 1] = dwords[index]
818                index = index + 1
819
820    def collect_dwords_and_length(self) -> t.Tuple[t.Dict[int, Group.DWord], int]:
821        dwords = {}
822        self.collect_dwords(dwords, 0)
823
824        # Determine number of dwords in this group. If we have a size, use
825        # that, since that'll account for MBZ dwords at the end of a group
826        # (like dword 8 on BDW+ 3DSTATE_HS). Otherwise, use the largest dword
827        # index we've seen plus one.
828        if self.size > 0:
829            length = self.size // 32
830        elif dwords:
831            length = max(dwords.keys()) + 1
832        else:
833            length = 0
834
835        return dwords, length
836
837    def emit_pack_function(self, root: Csbgen, dwords: t.Dict[int, Group.DWord], length: int) -> None:
838        for index in range(length):
839            # Handle MBZ dwords
840            if index not in dwords:
841                print("")
842                print("    dw[%d] = 0;" % index)
843                continue
844
845            # For 64 bit dwords, we aliased the two dword entries in the dword
846            # dict it occupies. Now that we're emitting the pack function,
847            # skip the duplicate entries.
848            dw = dwords[index]
849            if index > 0 and index - 1 in dwords and dw == dwords[index - 1]:
850                continue
851
852            # Special case: only one field and it's a struct at the beginning
853            # of the dword. In this case we pack directly into the
854            # destination. This is the only way we handle embedded structs
855            # larger than 32 bits.
856            if len(dw.fields) == 1:
857                field = dw.fields[0]
858                if root.is_known_struct(field.type) and field.start % 32 == 0:
859                    print("")
860                    print("    %s_pack(data, &dw[%d], &values->%s);"
861                          % (self.parser.gen_prefix(safe_name(field.type)), index, field.name))
862                    continue
863
864            # Pack any fields of struct type first so we have integer values
865            # to the dword for those fields.
866            field_index = 0
867            for field in dw.fields:
868                if root.is_known_struct(field.type):
869                    print("")
870                    print("    uint32_t v%d_%d;" % (index, field_index))
871                    print("    %s_pack(data, &v%d_%d, &values->%s);"
872                          % (self.parser.gen_prefix(safe_name(field.type)), index, field_index, field.name))
873                    field_index = field_index + 1
874
875            print("")
876            dword_start = index * 32
877            address_count = len(dw.addresses)
878
879            if dw.size == 32 and not dw.addresses:
880                v = None
881                print("    dw[%d] =" % index)
882            elif len(dw.fields) > address_count:
883                v = "v%d" % index
884                print("    const uint%d_t %s =" % (dw.size, v))
885            else:
886                v = "0"
887
888            field_index = 0
889            non_address_fields = []
890            for field in dw.fields:
891                if field.type == "mbo":
892                    non_address_fields.append("__pvr_mbo(%d, %d)"
893                                              % (field.start - dword_start, field.end - dword_start))
894                elif field.type == "address":
895                    pass
896                elif field.type == "uint":
897                    non_address_fields.append("__pvr_uint(values->%s, %d, %d)"
898                                              % (field.name, field.start - dword_start, field.end - dword_start))
899                elif root.is_known_enum(field.type):
900                    non_address_fields.append("__pvr_uint(values->%s, %d, %d)"
901                                              % (field.name, field.start - dword_start, field.end - dword_start))
902                elif field.type == "int":
903                    non_address_fields.append("__pvr_sint(values->%s, %d, %d)"
904                                              % (field.name, field.start - dword_start, field.end - dword_start))
905                elif field.type == "bool":
906                    non_address_fields.append("__pvr_uint(values->%s, %d, %d)"
907                                              % (field.name, field.start - dword_start, field.end - dword_start))
908                elif field.type == "float":
909                    non_address_fields.append("__pvr_float(values->%s)" % field.name)
910                elif field.type == "offset":
911                    non_address_fields.append("__pvr_offset(values->%s, %d, %d)"
912                                              % (field.name, field.start - dword_start, field.end - dword_start))
913                elif field.type == "uint_array":
914                    pass
915                elif field.is_struct_type():
916                    non_address_fields.append("__pvr_uint(v%d_%d, %d, %d)"
917                                              % (index, field_index, field.start - dword_start,
918                                                 field.end - dword_start))
919                    field_index = field_index + 1
920                else:
921                    non_address_fields.append(
922                        "/* unhandled field %s," " type %s */\n" % (field.name, field.type)
923                    )
924
925            if non_address_fields:
926                print(" |\n".join("      " + f for f in non_address_fields) + ";")
927
928            if dw.size == 32:
929                for addr in dw.addresses:
930                    print("    dw[%d] = __pvr_address(values->%s, %d, %d, %d) | %s;"
931                          % (index, addr.name, addr.shift, addr.start - dword_start,
932                             addr.end - dword_start, v))
933                continue
934
935            v_accumulated_addr = ""
936            for i, addr in enumerate(dw.addresses):
937                v_address = "v%d_address" % i
938                v_accumulated_addr += "v%d_address" % i
939                print("    const uint64_t %s =" % v_address)
940                print("      __pvr_address(values->%s, %d, %d, %d);"
941                      % (addr.name, addr.shift, addr.start - dword_start, addr.end - dword_start))
942                if i < (address_count - 1):
943                    v_accumulated_addr += " |\n            "
944
945            if dw.addresses:
946                if len(dw.fields) > address_count:
947                    print("    dw[%d] = %s | %s;" % (index, v_accumulated_addr, v))
948                    print("    dw[%d] = (%s >> 32) | (%s >> 32);" % (index + 1, v_accumulated_addr, v))
949                    continue
950                else:
951                    v = v_accumulated_addr
952
953            print("    dw[%d] = %s;" % (index, v))
954            print("    dw[%d] = %s >> 32;" % (index + 1, v))
955
956    def emit_unpack_function(self, root: Csbgen, dwords: t.Dict[int, Group.DWord], length: int) -> None:
957        for index in range(length):
958            # Ignore MBZ dwords
959            if index not in dwords:
960                continue
961
962            # For 64 bit dwords, we aliased the two dword entries in the dword
963            # dict it occupies. Now that we're emitting the unpack function,
964            # skip the duplicate entries.
965            dw = dwords[index]
966            if index > 0 and index - 1 in dwords and dw == dwords[index - 1]:
967                continue
968
969            # Special case: only one field and it's a struct at the beginning
970            # of the dword. In this case we unpack directly from the
971            # source. This is the only way we handle embedded structs
972            # larger than 32 bits.
973            if len(dw.fields) == 1:
974                field = dw.fields[0]
975                if root.is_known_struct(field.type) and field.start % 32 == 0:
976                    prefix = root.get_struct(field.type)
977                    print("")
978                    print("    %s_unpack(data, &dw[%d], &values->%s);" % (prefix, index, field.name))
979                    continue
980
981            dword_start = index * 32
982
983            if dw.size == 32:
984                v = "dw[%d]" % index
985            elif dw.size == 64:
986                v = "v%d" % index
987                print("    const uint%d_t %s = dw[%d] | ((uint64_t)dw[%d] << 32);" % (dw.size, v, index, index + 1))
988            else:
989                raise RuntimeError("Unsupported dword size %d" % dw.size)
990
991            # Unpack any fields of struct type first.
992            for field_index, field in enumerate(f for f in dw.fields if root.is_known_struct(f.type)):
993                prefix = root.get_struct(field.type).prefix
994                vname = "v%d_%d" % (index, field_index)
995                print("")
996                print("    uint32_t %s = __pvr_uint_unpack(%s, %d, %d);"
997                      % (vname, v, field.start - dword_start, field.end - dword_start))
998                print("    %s_unpack(data, &%s, &values->%s);" % (prefix, vname, field.name))
999
1000            for field in dw.fields:
1001                dword_field_start = field.start - dword_start
1002                dword_field_end = field.end - dword_start
1003
1004                if field.type == "mbo" or root.is_known_struct(field.type):
1005                    continue
1006                elif field.type == "uint" or root.is_known_enum(field.type) or field.type == "bool":
1007                    print("    values->%s = __pvr_uint_unpack(%s, %d, %d);"
1008                          % (field.name, v, dword_field_start, dword_field_end))
1009                elif field.type == "int":
1010                    print("    values->%s = __pvr_sint_unpack(%s, %d, %d);"
1011                          % (field.name, v, dword_field_start, dword_field_end))
1012                elif field.type == "float":
1013                    print("    values->%s = __pvr_float_unpack(%s);" % (field.name, v))
1014                elif field.type == "offset":
1015                    print("    values->%s = __pvr_offset_unpack(%s, %d, %d);"
1016                          % (field.name, v, dword_field_start, dword_field_end))
1017                elif field.type == "address":
1018                    print("    values->%s = __pvr_address_unpack(%s, %d, %d, %d);"
1019                          % (field.name, v, field.shift, dword_field_start, dword_field_end))
1020                else:
1021                    print("/* unhandled field %s, type %s */" % (field.name, field.type))
1022
1023
1024
class Parser:
    """Expat-driven reader that turns a csbgen XML file into a node tree.

    Open elements are tracked on ``context``, a stack whose top is the
    parent of whatever element is opened next. Closing the root ``csbgen``
    element calls ``emit()`` on it.
    """

    __slots__ = ["parser", "context", "filename"]

    parser: expat.XMLParserType
    context: t.List[Node]
    filename: str

    def __init__(self) -> None:
        self.parser = expat.ParserCreate()
        self.parser.StartElementHandler = self.start_element
        self.parser.EndElementHandler = self.end_element

        self.context = []
        self.filename = ""

    def start_element(self, name: str, attrs: t.Dict[str, str]) -> None:
        """Handle an opening tag by creating its node and pushing it.

        :param name: tag name.
        :param attrs: the tag's attributes.
        :raises RuntimeError: on an unknown tag, a second ``csbgen`` block,
            any tag outside of the ``csbgen`` block, a field without a
            ``start`` attribute outside of a stream, or an ``endif``
            condition that has no open condition to close.
        """
        if name == "csbgen":
            if self.context:
                raise RuntimeError(
                    "Can only have 1 csbgen block and it has to contain all of the other elements."
                )

            csbgen = Csbgen(attrs["name"], attrs["prefix"], self.filename)
            self.context.append(csbgen)
            return

        # Every other tag needs a parent. Report a missing csbgen root
        # explicitly rather than failing with an IndexError on the lookup.
        if not self.context:
            raise RuntimeError("Tag '%s' found outside of the csbgen block." % name)

        parent = self.context[-1]

        if name == "struct":
            struct = Struct(parent, attrs["name"], int(attrs["length"]))
            self.context.append(struct)

        elif name == "stream":
            stream = Stream(parent, attrs["name"], int(attrs["length"]))
            self.context.append(stream)

        elif name == "field":
            default = attrs.get("default")
            shift = attrs.get("shift")

            start_attr = attrs.get("start")
            if start_attr is None:
                # Only fields sitting directly in a stream may omit "start".
                if not isinstance(parent, Stream):
                    raise RuntimeError("Field requires start attribute outside of stream.")
                start = 0
            elif ":" in start_attr:
                # "word:bit" form; words are 32 bits wide.
                (word, bit) = start_attr.split(":")
                start = (int(word) * 32) + int(bit)
            else:
                start = int(start_attr)

            # "size" is relative to start; "end" is taken as given.
            if "size" in attrs:
                end = start + int(attrs["size"])
            else:
                end = int(attrs["end"])

            field = Field(parent, name=attrs["name"], start=start, end=end, ty=attrs["type"],
                          default=default, shift=shift)
            self.context.append(field)

        elif name == "enum":
            enum = Enum(parent, attrs["name"])
            self.context.append(enum)

        elif name == "value":
            value = Value(parent, attrs["name"], int(literal_eval(attrs["value"])))
            self.context.append(value)

        elif name == "define":
            define = Define(parent, attrs["name"], int(literal_eval(attrs["value"])))
            self.context.append(define)

        elif name == "condition":
            condition = Condition(parent, name=attrs["check"], ty=attrs["type"])

            # Starting with the if statement we push it in the context. For each
            # branch following (elif, and else) we assign the top of stack as
            # its parent, pop() and push the new condition. So per branch we end
            # up having [..., struct, condition]. We don't push an endif since
            # it's not supposed to have any children and it's supposed to close
            # the whole if statement.

            if condition.type != "if":
                # Remove the parent condition from the context. We were peeking
                # before, now we pop().
                self.context.pop()

            if condition.type == "endif":
                if not isinstance(parent, Condition):
                    raise RuntimeError("Cannot close unopened or already closed condition. Condition: '%s'"
                                       % condition.name)
            else:
                self.context.append(condition)

        else:
            raise RuntimeError("Unknown tag: '%s'" % name)

    def end_element(self, name: str) -> None:
        """Handle a closing tag by popping and type-checking the open node.

        ``condition`` tags are not popped here, since ``start_element``
        manages them on the stack. Closing ``csbgen`` triggers ``emit()``.

        :param name: tag name.
        :raises RuntimeError: if the node on top of the stack does not match
            the tag being closed, or the tag is unknown.
        """
        if name == "condition":
            element = self.context[-1]
            if not isinstance(element, (Condition, Struct)):
                raise RuntimeError("Expected condition or struct tag to be closed.")

            return

        element = self.context.pop()

        if name == "struct":
            if not isinstance(element, Struct):
                raise RuntimeError("Expected struct tag to be closed.")
        elif name == "stream":
            if not isinstance(element, Stream):
                raise RuntimeError("Expected stream tag to be closed.")
        elif name == "field":
            if not isinstance(element, Field):
                raise RuntimeError("Expected field tag to be closed.")
        elif name == "enum":
            if not isinstance(element, Enum):
                raise RuntimeError("Expected enum tag to be closed.")
        elif name == "value":
            if not isinstance(element, Value):
                raise RuntimeError("Expected value tag to be closed.")
        elif name == "define":
            if not isinstance(element, Define):
                raise RuntimeError("Expected define tag to be closed.")
        elif name == "csbgen":
            if not isinstance(element, Csbgen):
                raise RuntimeError("Expected csbgen tag to be closed.\nSome tags may have not been closed")

            element.emit()
        else:
            raise RuntimeError("Unknown closing element: '%s'" % name)

    def parse(self, filename: str) -> None:
        """Parse the XML file at ``filename``.

        The file is closed even when parsing raises.
        """
        with open(filename, "rb") as file:
            self.filename = filename
            self.parser.ParseFile(file)
1171
1172
if __name__ == "__main__":
    import sys

    # Script entry point: the first command line argument is the XML file
    # to parse.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("No input xml file specified")
        sys.exit(1)

    Parser().parse(cli_args[0])
1184