# xref: /aosp_15_r20/external/mesa3d/src/vulkan/runtime/vk_format_info_gen.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
# License header (a C block comment) prepended to both generated files.
COPYRIGHT = """
/* Copyright © 2022 Collabora, Ltd.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
"""
24
25import argparse
26import os
27import re
28from collections import namedtuple
29import xml.etree.ElementTree as et
30
31from mako.template import Template
32
# Mako template for the generated header.  It declares one
# MESA_VK_FORMAT_CLASS_* enumerator per entry of `format_classes`, the
# vk_format_class_info struct and the two lookup functions.
#
# The header is wrapped in `extern "C"` so that C++ code can include it, which
# means no identifier in it may be a C++ keyword: the class parameter is
# therefore named `format_class` rather than `class`.
TEMPLATE_H = Template(COPYRIGHT + """\
/* This file generated from ${filename}, don't edit directly. */

#ifndef VK_FORMAT_INFO_H
#define VK_FORMAT_INFO_H

#include <vulkan/vulkan_core.h>

#ifdef __cplusplus
extern "C" {
#endif

enum vk_format_class {
   MESA_VK_FORMAT_CLASS_UNKNOWN,
% for name in format_classes:
   ${to_enum_name('MESA_VK_FORMAT_CLASS_', name)},
% endfor
};

struct vk_format_class_info {
   const VkFormat *formats;
   uint32_t format_count;
};

const struct vk_format_class_info *
vk_format_class_get_info(enum vk_format_class format_class);

const struct vk_format_class_info *
vk_format_get_class_info(VkFormat format);

#ifdef __cplusplus
}
#endif

#endif
""")
69
# Mako template for the generated C file.  It emits:
#   * one sparse `ext<N>_format_infos` table per key of `extensions`, indexed
#     by the format's offset,
#   * vk_format_get_info(), which splits a VkFormat value into an extension
#     number and an offset and indexes the matching table,
#   * one `<class>_formats` array per non-empty format class, and
#   * the `class_infos` table indexed by enum vk_format_class.
#
# The closing brace of each per-class array lives inside the
# `% if len(cls.formats) > 0` guard, so a class without formats emits nothing
# at all instead of a stray `};`.
TEMPLATE_C = Template(COPYRIGHT + """
/* This file generated from ${filename}, don't edit directly. */

#include "${header}"

#include "util/macros.h"

#include "vk_format.h"

struct vk_format_info {
   enum vk_format_class class;
};

% for id, ext in extensions.items():
static const struct vk_format_info ext${id}_format_infos[] = {
%   for name, format in ext.formats.items():
   [${format.offset}] = {
      .class = ${to_enum_name('MESA_VK_FORMAT_CLASS_', format.cls)},
   },
%   endfor
};

% endfor
static const struct vk_format_info *
vk_format_get_info(VkFormat format)
{
   uint32_t extnumber =
      format < 1000000000 ? 0 : (((format % 1000000000) / 1000) + 1);
   uint32_t offset = format % 1000;

   switch (extnumber) {
% for id, ext in extensions.items():
   case ${id}:
      assert(offset < ARRAY_SIZE(ext${id}_format_infos));
      return &ext${id}_format_infos[offset];
% endfor
   default:
      unreachable("Invalid extension");
   }
}

% for clsname, cls in format_classes.items():
%   if len(cls.formats) > 0:
static const VkFormat ${to_enum_name('MESA_VK_FORMAT_CLASS_', clsname).lower() + '_formats'}[] = {
%     for fname in cls.formats:
   ${fname},
%     endfor
};

%   endif
% endfor
static const struct vk_format_class_info class_infos[] = {
% for clsname, cls in format_classes.items():
   [${to_enum_name('MESA_VK_FORMAT_CLASS_', clsname)}] = {
%   if len(cls.formats) > 0:
      .formats = ${to_enum_name('MESA_VK_FORMAT_CLASS_', clsname).lower() + '_formats'},
      .format_count = ARRAY_SIZE(${to_enum_name('MESA_VK_FORMAT_CLASS_', clsname).lower() + '_formats'}),
%   else:
      0
%   endif
   },
% endfor
};

const struct vk_format_class_info *
vk_format_class_get_info(enum vk_format_class class)
{
   assert(class < ARRAY_SIZE(class_infos));
   return &class_infos[class];
}

const struct vk_format_class_info *
vk_format_get_class_info(VkFormat format)
{
    const struct vk_format_info *format_info = vk_format_get_info(format);
    return &class_infos[format_info->class];
}
""")
148
def to_enum_name(prefix, name):
    """Build a C enumerator name from *prefix* and a free-form *name*.

    Every character of *name* that is not an ASCII letter, digit or
    underscore is replaced by an underscore and the result is upper-cased.
    *prefix* is prepended unchanged.
    """
    sanitized = re.sub('([^A-Za-z0-9_])', '_', name)
    return "%s" % prefix + sanitized.upper()
151
# Plain records for the data read from the registry:
#   Format      - a format name, its class name, the number of the extension
#                 that defines it (0 for core enum values) and its offset
#   FormatClass - a class name and the formats that belong to it
#   Extension   - an extension id and the formats it defines
Format = namedtuple('Format', ['name', 'cls', 'ext', 'offset'])
FormatClass = namedtuple('FormatClass', ['name', 'formats'])
Extension = namedtuple('Extension', ['id', 'formats'])


def _enclosing_extension_number(node, parent_of):
    """Return the number of the closest <extension> ancestor of *node*.

    *parent_of* maps every element to its parent.  Returns None when no
    ancestor is an <extension> element.

    Raises:
        ValueError: the closest <extension> ancestor has no `number`.
    """
    ancestor = parent_of.get(node)
    while ancestor is not None:
        if ancestor.tag == 'extension':
            if 'number' not in ancestor.attrib:
                raise ValueError('<extension> element without a "number" '
                                 'attribute')
            # Always an int, like the `extnumber` path in get_formats();
            # otherwise the same extension could appear under both 157 and
            # '157' in dictionaries keyed on it.
            return int(ancestor.attrib['number'])
        ancestor = parent_of.get(ancestor)
    return None


def get_formats(doc):
    """Extract the formats from the registry.

    For every `./formats/format` element, the first <enum> in the document
    with the same name supplies the numeric identity of the format:

    * an enum with an `extends` attribute takes its extension number from
      its own `extnumber` attribute or, failing that, from the closest
      enclosing <extension> element, and its offset from `offset`;
    * any other enum belongs to extension 0 and uses `value` as its offset.

    Args:
        doc: parsed registry (an ElementTree or its root element).

    Returns:
        A dict mapping each format name to its Format record.  `ext` and
        `offset` are always ints.

    Raises:
        ValueError: a format has no matching <enum>, its enum extends
            something other than VkFormat, or no extension number can be
            determined for it.
    """
    formats = {}

    # ElementTree elements do not know their parent, so index the tree once
    # instead of re-running an XPath query for every ancestor level.
    parent_of = {child: parent for parent in doc.iter() for child in parent}

    for fmt in doc.findall('./formats/format'):
        name = fmt.attrib['name']
        enum = doc.find('.//enum[@name="{}"]'.format(name))
        if enum is None:
            raise ValueError('no <enum> found for format {}'.format(name))

        if 'extends' in enum.attrib:
            if enum.attrib['extends'] != 'VkFormat':
                raise ValueError('{} extends {}, expected VkFormat'.format(
                    name, enum.attrib['extends']))
            if 'extnumber' in enum.attrib:
                ext = int(enum.attrib['extnumber'])
            else:
                ext = _enclosing_extension_number(enum, parent_of)
            offset = int(enum.attrib['offset'])
        else:
            ext = 0
            offset = int(enum.attrib['value'])

        if ext is None:
            raise ValueError(
                'cannot determine the extension number of {}'.format(name))

        formats[name] = Format(name, fmt.attrib['class'], ext, offset)

    return formats
187
def get_formats_from_xml(xml_files):
    """Parse each registry file and merge the formats found in them.

    The files are processed in the given order; when a format name occurs in
    more than one file, the entry from the later file replaces the earlier
    one.

    Returns a dict mapping format names to their Format records.
    """
    merged = {}
    for path in xml_files:
        merged.update(get_formats(et.parse(path)))
    return merged
196
def main():
    """Generate the format-info header and C file from Vulkan registry XML.

    Command-line arguments:
        --out-c  path of the C file to write
        --out-h  path of the header to write
        --xml    registry XML file; may be given several times

    Both outputs must live in the same directory because the C file includes
    the header by its base name.  If writing or rendering fails, a traceback
    formatted by mako is printed to stderr and the process exits with
    status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--out-c', required=True, help='Output C file.')
    parser.add_argument('--out-h', required=True, help='Output H file.')
    parser.add_argument('--xml',
                        help='Vulkan API XML file.',
                        required=True, action='append', dest='xml_files')
    args = parser.parse_args()

    # Checked explicitly rather than with an assert, which `python -O` strips.
    if os.path.dirname(args.out_c) != os.path.dirname(args.out_h):
        parser.error('--out-c and --out-h must be in the same directory')

    formats = get_formats_from_xml(args.xml_files)

    # Group the formats by class name and by extension number.
    classes = {}
    extensions = {}
    for fmt in formats.values():
        if fmt.cls not in classes:
            classes[fmt.cls] = FormatClass(fmt.cls, {})
        classes[fmt.cls].formats[fmt.name] = fmt
        if fmt.ext not in extensions:
            # The record's id is the extension number, i.e. its own key.
            extensions[fmt.ext] = Extension(fmt.ext, {})
        extensions[fmt.ext].formats[fmt.name] = fmt

    environment = {
        'header': os.path.basename(args.out_h),
        'formats': formats,
        'format_classes': classes,
        'extensions': extensions,
        'filename': os.path.basename(__file__),
        'to_enum_name': to_enum_name,
    }

    try:
        with open(args.out_h, 'w', encoding='utf-8') as out_h:
            guard = os.path.basename(args.out_h).replace('.', '_').upper()
            out_h.write(TEMPLATE_H.render(guard=guard, **environment))
        with open(args.out_c, 'w', encoding='utf-8') as out_c:
            out_c.write(TEMPLATE_C.render(**environment))
    except Exception:
        # Top-level boundary: let mako format the active exception (it can
        # point at the failing template line), print it to stderr and exit
        # with status 1.
        import sys
        from mako import exceptions
        print(exceptions.text_error_template().render(), file=sys.stderr)
        sys.exit(1)
243
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
246