xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/gen_vulkan_spv.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env python3
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8# pyre-unsafe
9
10import argparse
11import array
12import codecs
13import copy
14import glob
15import io
16import os
17import re
18import sys
19from itertools import product
20from multiprocessing.pool import ThreadPool
21
22sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
23import subprocess
24import textwrap
25from dataclasses import dataclass
26from typing import Any, Dict, List, Optional, Set, Tuple, Union
27
28import yaml
29from yaml.constructor import ConstructorError
30from yaml.nodes import MappingNode
31
32try:
33    from yaml import CLoader as Loader
34except ImportError:
35    from yaml import Loader  # type: ignore[assignment, misc]
36
# File names of the generated C++ header and source (presumably consumed by
# the C++ generation code further down in this file).
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"

# Base template environment that every shader starts from.
DEFAULT_ENV: Dict[str, Any] = dict(
    PRECISION="highp",
    # "B" (binding): next layout binding index. Helpers such as
    # layout_declare_buffer bump a list-valued slot in place, which is why the
    # counter is a one-element list rather than an immutable int.
    B=[0],
    # "C" (constant_id): next specialization constant id. It starts at 3
    # because ids 0-2 are reserved for the local workgroup size.
    C=[3],
)
51}
52
# Lookup tables from tensor dtype (and, for image/sampler types, the image
# rank) to the GLSL type or image format used to declare it. Signed integer
# dtypes use the "i"-prefixed GLSL types and unsigned ones the "u"-prefixed
# types; the 8-bit integer dtypes share those image/sampler types and differ
# only in their image format.
TYPE_MAPPINGS: Dict[str, Any] = {
    "IMAGE_T": {
        ndim: {
            "float": f"image{ndim}D",
            "half": f"image{ndim}D",
            "int": f"iimage{ndim}D",
            "uint": f"uimage{ndim}D",
            "int8": f"iimage{ndim}D",
            "uint8": f"uimage{ndim}D",
        }
        for ndim in (3, 2)
    },
    "SAMPLER_T": {
        ndim: {
            "float": f"sampler{ndim}D",
            "half": f"sampler{ndim}D",
            "int": f"isampler{ndim}D",
            "uint": f"usampler{ndim}D",
            "int8": f"isampler{ndim}D",
            "uint8": f"usampler{ndim}D",
        }
        for ndim in (3, 2)
    },
    "IMAGE_FORMAT": {
        "float": "rgba32f",
        "half": "rgba16f",
        "int": "rgba32i",
        "uint": "rgba32ui",
        "int8": "rgba8i",
        "uint8": "rgba8ui",
    },
}
100
101
def define_variable(name: str) -> str:
    """Return a ``#define`` directive for the module-level global ``name``.

    The value is looked up among this module's globals and rendered with
    ``str()``. (The previous ``locals()`` lookup was removed: inside this
    function ``locals()`` only ever holds ``name`` itself, so it could only
    produce the meaningless ``#define name name``.)

    Args:
        name: name of a global defined in this module.

    Returns:
        ``"#define <name> <value>"``.

    Raises:
        RuntimeError: if ``name`` is not a global of this module.
    """
    module_globals = globals()
    if name not in module_globals:
        raise RuntimeError(f"{name} is not defined")
    return f"#define {name} {module_globals[name]}"
109
110
def buffer_scalar_type(dtype: str) -> str:
    """Return the GLSL scalar type used for one ``dtype`` element in a buffer.

    ``half`` maps to ``float16_t``. Any dtype name ending in ``8`` or ``16``
    (``int8``, ``uint8``, ``int16``, ``uint16``) gets the ``_t`` suffix of the
    explicit-arithmetic GLSL types. Every other dtype name is returned as-is.
    """
    if dtype == "half":
        return "float16_t"
    # endswith (unlike dtype[-1]) is also safe for an empty string, and covers
    # the 16-bit integer dtypes that define_required_extensions supports.
    if dtype.endswith(("8", "16")):
        return dtype + "_t"
    return dtype
118
119
def buffer_gvec_type(dtype: str, n: int) -> str:
    """Return the GLSL type of an ``n``-component vector of ``dtype`` in a buffer.

    ``n == 1`` yields the scalar type from ``buffer_scalar_type``.

    Raises:
        AssertionError: if ``dtype`` has no GLSL vector type in the table below.
    """
    if n == 1:
        return buffer_scalar_type(dtype)

    # GLSL vector type prefix per dtype. "uint" was previously missing even
    # though it is a supported dtype everywhere else (e.g. TYPE_MAPPINGS), and
    # the 16-bit integer vectors match the dtypes define_required_extensions
    # enables.
    vec_prefixes = {
        "float": "vec",
        "half": "f16vec",
        "int": "ivec",
        "uint": "uvec",
        "int8": "i8vec",
        "uint8": "u8vec",
        "int16": "i16vec",
        "uint16": "u16vec",
    }
    prefix = vec_prefixes.get(dtype)
    if prefix is None:
        raise AssertionError(f"Invalid dtype: {dtype}")
    return f"{prefix}{n}"
136
137
def texel_type(dtype: str) -> str:
    """Return the 4-component GLSL vector type of one texel of ``dtype``.

    The type follows the suffix of the dtype's image format in
    ``TYPE_MAPPINGS["IMAGE_FORMAT"]``: ``f`` -> ``vec4``, ``ui`` -> ``uvec4``,
    ``i`` -> ``ivec4``.

    Raises:
        KeyError: if ``dtype`` has no image format.
        AssertionError: if the image format has an unrecognized suffix.
    """
    image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
    if image_format.endswith("f"):
        return "vec4"
    # Must be tested before the plain "i" suffix, which every unsigned format
    # also ends with. (The old image_format[-2] == "ui" compared a single
    # character with two and never matched, so uint dtypes became ivec4.)
    if image_format.endswith("ui"):
        return "uvec4"
    if image_format.endswith("i"):
        return "ivec4"
    raise AssertionError(f"Invalid image format: {image_format}")
147
148
def gvec_type(dtype: str, n: int) -> str:
    """Return the ``n``-component counterpart of ``texel_type(dtype)``.

    The trailing component count of the texel's vec4 type is replaced by
    ``n`` (e.g. ``ivec4`` -> ``ivec2``).
    """
    vec4_name = texel_type(dtype)
    return vec4_name[:-1] + str(n)
152
153
def texel_component_type(dtype: str) -> str:
    """Return the GLSL scalar type of one component of a ``dtype`` texel.

    Derived from ``texel_type(dtype)``: ``vec*`` -> ``float``,
    ``ivec*`` -> ``int``, ``uvec*`` -> ``uint``.

    Raises:
        AssertionError: if the texel type has none of those prefixes.
    """
    vec4_name = texel_type(dtype)
    for prefix, scalar in (("vec", "float"), ("ivec", "int"), ("uvec", "uint")):
        if vec4_name.startswith(prefix):
            return scalar
    raise AssertionError(f"Invalid vec4 type: {vec4_name}")
163
164
def texel_load_type(dtype: str, storage_type: str) -> str:
    """Return the 4-component type obtained when loading a ``dtype`` texel.

    Buffer-backed tensors (``storage_type`` "buffer", case-insensitive) use
    the buffer vec4 type; every other storage type uses the image texel type.
    """
    is_buffer = storage_type.lower() == "buffer"
    return buffer_gvec_type(dtype, 4) if is_buffer else texel_type(dtype)
170
171
def texel_load_component_type(dtype: str, storage_type: str) -> str:
    """Return the scalar type of one component loaded from a ``dtype`` tensor.

    Buffer-backed tensors (``storage_type`` "buffer", case-insensitive) use
    the buffer scalar type; every other storage type uses the image texel's
    component type.
    """
    if storage_type.lower() != "buffer":
        return texel_component_type(dtype)
    return buffer_scalar_type(dtype)
177
178
def get_access_qualifier(access_type: Optional[str]) -> str:
    """Map an access code to a GLSL memory qualifier.

    ``None`` and ``"rw"`` need no qualifier (empty string); ``"r"`` is
    ``readonly`` and ``"w"`` is ``writeonly``. Codes are case-insensitive.

    Raises:
        AssertionError: for any other access code.
    """
    if access_type is None:
        return ""
    qualifier = {"r": "readonly", "w": "writeonly", "rw": ""}.get(access_type.lower())
    # "" is a valid result for "rw", so test for None rather than falsiness.
    if qualifier is None:
        raise AssertionError(f"Invalid access type: {access_type}")
    return qualifier
190
191
def get_slot_val(slot: Union[int, List[int]]) -> int:
    """Return the index held by ``slot``.

    ``slot`` is either a plain index or a one-element list used as a mutable
    counter; for a list only the first element is read.
    """
    return slot[0] if isinstance(slot, list) else slot
196
197
def layout_declare_buffer(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    precision: str = "PRECISION",
    is_scalar_array: bool = True,
) -> str:
    """Declare a storage buffer binding holding the unsized array ``var_name``.

    Args:
        slot: binding index, or a one-element list used as a counter; a list
            is incremented in place after its value is used.
        access_type: access code passed to ``get_access_qualifier``.
        var_name: name of the array; the block is named ``<var_name>Buffer``.
        dtype: element dtype.
        precision: GLSL precision qualifier text.
        is_scalar_array: if True the elements are scalars, otherwise
            4-component vectors of ``dtype``.

    Returns:
        The GLSL declaration, surrounded by newlines.
    """
    # Only compute the element type that is actually used. The vec4 variant
    # is not defined for every dtype and previously made scalar buffers of
    # such dtypes fail even though their vec4 type was never needed.
    if is_scalar_array:
        array_type = buffer_scalar_type(dtype)
    else:
        array_type = buffer_gvec_type(dtype, 4)

    out_str = f"""
layout(set = 0, binding = {get_slot_val(slot)}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{
    {array_type} {var_name}[];
}};
"""

    if isinstance(slot, list):
        slot[0] = slot[0] + 1
    return out_str
219
220
def layout_declare_image(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    precision: str = "PRECISION",
    image_ndim: int = 3,
) -> str:
    """Declare a storage image binding named ``var_name``.

    The image format and image type come from ``TYPE_MAPPINGS`` for ``dtype``
    and ``image_ndim``. If ``slot`` is a list, its first element supplies the
    binding index and is incremented afterwards.
    """
    fmt = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
    glsl_type = TYPE_MAPPINGS["IMAGE_T"][image_ndim][dtype]
    binding = get_slot_val(slot)
    qualifier = get_access_qualifier(access_type)

    declaration = (
        f"layout(set = 0, binding = {binding}, {fmt}) uniform {precision} "
        f"restrict {qualifier} {glsl_type} {var_name};"
    )

    if isinstance(slot, list):
        slot[0] += 1
    return declaration
237
238
def layout_declare_sampler(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    precision: str = "PRECISION",
    access_qualifier: Optional[str] = None,
    image_ndim: int = 3,
) -> str:
    """Declare a combined image sampler binding named ``var_name``.

    The sampler type comes from ``TYPE_MAPPINGS["SAMPLER_T"]`` for
    ``image_ndim`` and ``dtype``. ``access_type`` and ``access_qualifier`` are
    accepted for signature compatibility but not used. If ``slot`` is a list,
    its first element supplies the binding index and is incremented afterwards.
    """
    glsl_sampler = TYPE_MAPPINGS["SAMPLER_T"][image_ndim][dtype]
    binding = get_slot_val(slot)

    if isinstance(slot, list):
        slot[0] += 1
    return f"layout(set = 0, binding = {binding}) uniform {precision} {glsl_sampler} {var_name};"
255
256
def layout_declare_tensor(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    storage_type: str,
    is_scalar_array: bool = True,
    precision: str = "PRECISION",
) -> str:
    """Declare a tensor binding backed by a buffer, a sampler, or a storage image.

    ``storage_type`` (case-insensitive) must be "buffer", "texture3d" or
    "texture2d". Buffers delegate to ``layout_declare_buffer``, and
    ``is_scalar_array`` applies only to them. For textures, read-only access
    (``access_type`` "r") declares a sampler and any other access declares a
    storage image, with a 2D or 3D image type matching the storage type.
    """
    storage = storage_type.lower()
    assert storage in ["buffer", "texture3d", "texture2d"]

    if storage == "buffer":
        return layout_declare_buffer(
            slot,
            access_type,
            var_name,
            dtype,
            precision,
            is_scalar_array=is_scalar_array,
        )

    image_ndim = 2 if storage == "texture2d" else 3
    declare = (
        layout_declare_sampler if access_type.lower() == "r" else layout_declare_image
    )
    return declare(slot, access_type, var_name, dtype, precision, image_ndim=image_ndim)
292
293
def layout_declare_ubo(
    slot: Union[int, List[int]], *args, precision: str = "PRECISION"
) -> str:
    """Declare a read-only uniform block from alternating (type, name) arguments.

    ``args`` is ``type0, name0, type1, name1, ...``. The block is named after
    its members joined by underscores, followed by ``UBO``
    (e.g. ``out_limits_n_UBO``). If ``slot`` is a list, its first element
    supplies the binding index and is incremented afterwards.
    """
    assert len(args) % 2 == 0

    members = list(zip(args[::2], args[1::2]))
    ubo_name = "".join([member_name + "_" for _, member_name in members])

    pieces = [
        f"\nlayout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{\n"
    ]
    pieces.extend(
        f"  {member_type} {member_name};\n" for member_type, member_name in members
    )
    pieces.append("};")

    if isinstance(slot, list):
        slot[0] += 1
    return "".join(pieces)
315
316
def layout_declare_spec_const(
    slot: Union[int, List[int]],
    type_name: str,
    var_name: str,
    initial_val: Optional[str] = None,
) -> str:
    """Declare a specialization constant, optionally with a default value.

    ``type_name`` must be int, uint, float or bool. If ``slot`` is a list, its
    first element supplies the constant_id and is incremented afterwards.
    """
    assert type_name in ["int", "uint", "float", "bool"]

    declaration = f"layout(constant_id = {get_slot_val(slot)}) const {type_name} {var_name}"
    if initial_val is not None:
        declaration = f"{declaration} = {initial_val}"

    if isinstance(slot, list):
        slot[0] += 1
    return declaration + ";"
333
334
def define_active_storage_type(storage_type: str):
    """Return the ``#define USING_*`` directive for ``storage_type``.

    ``storage_type`` is case-insensitive and must be "buffer", "texture3d" or
    "texture2d".

    Raises:
        AssertionError: for any other storage type.
    """
    directives = {
        "buffer": "#define USING_BUFFER",
        "texture3d": "#define USING_TEXTURE3D",
        "texture2d": "#define USING_TEXTURE2D",
    }
    directive = directives.get(storage_type.lower())
    if directive is None:
        raise AssertionError(f"Invalid storage type: {storage_type}")
    return directive
344
345
def define_required_extensions(dtypes: Union[str, List[str]]):
    """Return the ``#extension`` directives needed to use the given dtype(s).

    ``dtypes`` is one dtype name or a list of them. ``half`` and the 16-bit
    integer dtypes need 16-bit storage, and the 8-bit integer dtypes need
    8-bit storage, each together with the matching explicit-arithmetic-types
    extension. Other dtypes contribute nothing. Directives are not
    de-duplicated, and the result always starts with a newline.
    """
    directives = ["\n"]
    for dtype in (dtypes if isinstance(dtypes, list) else [dtypes]):
        if dtype == "half":
            storage_bits, arith_type = "16bit", "float16"
        elif dtype in ("int16", "uint16"):
            storage_bits, arith_type = "16bit", "int16"
        elif dtype in ("int8", "uint8"):
            storage_bits, arith_type = "8bit", "int8"
        else:
            continue
        directives.append(f"#extension GL_EXT_shader_{storage_bits}_storage : require\n")
        directives.append(
            f"#extension GL_EXT_shader_explicit_arithmetic_types_{arith_type} : require\n"
        )
    return "".join(directives)
368
369
# Helper functions and lookup tables exposed to shader templates by name
# (presumably merged into each template's environment by code outside this
# view -- TODO confirm).
UTILITY_FNS: Dict[str, Any] = dict(
    macro_define=define_variable,
    # Position expression per image rank: 3D positions are used as-is, while
    # 2D images only take the .xy components.
    get_pos={
        3: lambda pos: pos,
        2: lambda pos: f"{pos}.xy",
    },
    buffer_scalar_type=buffer_scalar_type,
    buffer_gvec_type=buffer_gvec_type,
    texel_type=texel_type,
    gvec_type=gvec_type,
    texel_component_type=texel_component_type,
    texel_load_type=texel_load_type,
    texel_load_component_type=texel_load_component_type,
    layout_declare_buffer=layout_declare_buffer,
    layout_declare_image=layout_declare_image,
    layout_declare_sampler=layout_declare_sampler,
    layout_declare_tensor=layout_declare_tensor,
    layout_declare_ubo=layout_declare_ubo,
    layout_declare_spec_const=layout_declare_spec_const,
    define_active_storage_type=define_active_storage_type,
    define_required_extensions=define_required_extensions,
)
392
393
def extract_filename(path: str, keep_ext: bool = True) -> Any:
    """Return the final component of ``path``.

    With ``keep_ext=False``, everything from the first "." onward is dropped
    (so "a.b.glsl" becomes "a").
    """
    base = os.path.basename(path)
    return base if keep_ext else base.partition(".")[0]
399
400
401############################
402#  SPIR-V Code Generation  #
403############################
404
405
406# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
    """YAML loader that rejects a mapping containing the same key twice.

    The stock loaders silently keep the last duplicate. Duplicate template or
    variant keys are almost certainly mistakes, so they raise here instead.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )
        constructed = {}
        for k_node, v_node in node.value:
            k = self.construct_object(k_node, deep=deep)  # type: ignore[no-untyped-call]
            # Keys must be hashable to live in a dict.
            try:
                hash(k)
            except TypeError as err:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found unacceptable key ",
                    k_node.start_mark,
                ) from err
            if k in constructed:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found duplicate key",
                    k_node.start_mark,
                )
            # The value is only built after the key has been validated.
            constructed[k] = self.construct_object(v_node, deep=deep)  # type: ignore[no-untyped-call]
        return constructed
439
440
441# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def extract_leading_whitespace(line: str) -> str:
    """Return the run of whitespace at the start of ``line`` (possibly empty)."""
    leading = re.match(r"\s*", line)
    # `\s*` can match the empty string, so this always succeeds; the None
    # check only exists to satisfy type checkers.
    return "" if leading is None else leading.group()
445
446
447# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def escape(line: str) -> str:
    """Translate one template text line into a Python string expression.

    Each ``${expr}`` becomes ``str(expr)``. The literal text around it is
    emitted with ``repr`` so that quotes, backslashes (e.g. a GLSL macro line
    continuation) and other special characters reach the output unchanged.
    Previously only ``"`` was escaped, so a backslash produced invalid or
    altered Python. An empty line yields ``''``, so the caller always gets a
    valid expression and never the invalid ``print(, ...)``.

    Raises:
        ValueError: if a ``${`` has no closing ``}``.
    """
    output_parts = []
    remaining = line
    while "${" in remaining:
        start_pos = remaining.index("${")
        end_pos = remaining.find("}", start_pos + 2)
        if end_pos == -1:
            raise ValueError(f"Unterminated '${{' in template line: {line!r}")
        if start_pos != 0:
            output_parts.append(repr(remaining[:start_pos]))
        # NOTE: the expression ends at the first "}", so it cannot itself
        # contain a closing brace.
        output_parts.append("str(" + remaining[start_pos + 2 : end_pos] + ")")
        remaining = remaining[end_pos + 1 :]
    if remaining or not output_parts:
        output_parts.append(repr(remaining))
    return " + ".join(output_parts)
460
461
462# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess(
    input_text: str, variables: Dict[str, Any], input_path: str = "codegen"
) -> str:
    """Expand an xngen-style template and return the generated text.

    The template is translated line by line into a Python program. That
    program is then executed with a shallow copy of ``variables`` as its
    globals:

    * A line whose first non-blank character is ``$`` (and which does not
      start with ``${``) is a Python statement. Every ``$`` on it is removed,
      and a trailing ``:`` opens a block made of the following lines.
    * Every other line is printed with the enclosing block's indentation
      prefix removed and each ``${expr}`` replaced by ``str(expr)`` (see
      ``escape``).
    * Empty lines are reproduced. Any line containing ``LINT`` anywhere is
      dropped.

    NOTE(review): the template is executed as Python, so it must come from a
    trusted source (presumably the in-tree shader templates).

    Args:
        input_text: template source.
        variables: names visible to the template's statements and expressions.
        input_path: filename reported by ``compile`` in errors and tracebacks
            from the generated program.

    Returns:
        Everything the generated program printed.
    """
    input_lines = input_text.splitlines()
    python_lines = []

    blank_lines = 0

    last_indent = ""

    # Stack of (input_indent, python_indent) pairs, one per open block.
    indent_stack = [("", "")]

    # Python indentation of the most recent non-blank line. Initialized so that
    # a template made only of blank/LINT lines does not hit an unbound name in
    # the trailing blank-line flush below.
    python_indent = ""

    # Indicates whether this is the first line inside Python
    # code block (i.e. for, while, if, elif, else)
    python_block_start = True
    for input_line in input_lines:
        if input_line == "":
            blank_lines += 1
            continue
        # Skip lint markers.
        if "LINT" in input_line:
            continue

        input_indent = extract_leading_whitespace(input_line)
        if python_block_start:
            # First line of the template or of a new block: its extra
            # indentation relative to the header becomes the block's indent.
            assert input_indent.startswith(last_indent)
            extra_python_indent = input_indent[len(last_indent) :]
            python_indent = indent_stack[-1][1] + extra_python_indent
            indent_stack.append((input_indent, python_indent))
            assert input_indent.startswith(indent_stack[-1][0])
        else:
            # Dedent: close every block this line is no longer inside.
            while not input_indent.startswith(indent_stack[-1][0]):
                del indent_stack[-1]
        python_block_start = False

        python_indent = indent_stack[-1][1]
        stripped_input_line = input_line.strip()
        if stripped_input_line.startswith("$") and not stripped_input_line.startswith(
            "${"
        ):
            # Python statement line.
            if stripped_input_line.endswith(":"):
                python_block_start = True
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(python_indent + stripped_input_line.replace("$", ""))
        else:
            # Text line: print it, substituting ${...} expressions.
            assert input_line.startswith(python_indent)
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(
                python_indent
                + "print(%s, file=OUT_STREAM)"
                % escape(input_line[len(python_indent) :])
            )
        last_indent = input_indent

    while blank_lines != 0:
        python_lines.append(python_indent + "print(file=OUT_STREAM)")
        blank_lines -= 1

    exec_globals = dict(variables)
    output_stream = io.StringIO()
    exec_globals["OUT_STREAM"] = output_stream

    python_bytecode = compile("\n".join(python_lines), input_path, "exec")
    exec(python_bytecode, exec_globals)

    return output_stream.getvalue()
534
535
class SPVGenerator:
    """Expand GLSL shader templates into concrete variants and compile them.

    Every ``*.glsl*`` and ``*.yaml`` file below each source directory is
    collected. Each top-level key of a YAML file names a GLSL template and
    describes the variants generated from it. GLSL files that no YAML file
    mentions are emitted once, using only the base environment.
    """

    def __init__(
        self,
        src_dir_paths: Union[str, List[str]],
        env: Dict[Any, Any],
        glslc_path: Optional[str],
        glslc_flags: str = "",
        replace_u16vecn: bool = False,
    ) -> None:
        """Collect sources and templates and build the output shader map.

        Args:
            src_dir_paths: directory, or list of directories, searched
                recursively for GLSL sources and YAML parameter files. They are
                also passed to glslc as include directories.
            env: base template environment, deep-copied for every shader.
            glslc_path: path to the glslc executable. If None, only the
                preprocessed GLSL is written and nothing is compiled.
            glslc_flags: extra whitespace-separated flags appended to every
                glslc invocation.
            replace_u16vecn: if True, rewrite 16-bit integer position types
                before preprocessing (see maybe_replace_u16vecn).
        """
        if isinstance(src_dir_paths, str):
            self.src_dir_paths = [src_dir_paths]
        else:
            self.src_dir_paths = src_dir_paths

        self.env = env
        self.glslc_path = glslc_path
        self.glslc_flags = glslc_flags
        self.replace_u16vecn = replace_u16vecn

        # Shader name (file name up to its first ".") -> GLSL source path.
        self.glsl_src_files: Dict[str, str] = {}
        self.template_yaml_files: List[str] = []

        self.addSrcAndYamlFiles(self.src_dir_paths)
        # Template name -> list of resolved parameter dicts, one per variant.
        self.shader_template_params: Dict[Any, Any] = {}
        for yaml_file in self.template_yaml_files:
            self.parseTemplateYaml(yaml_file)

        # Output shader name -> (GLSL source path, template environment).
        self.output_shader_map: Dict[str, Tuple[str, Dict[str, Any]]] = {}
        self.constructOutputMap()

    def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
        """Recursively collect GLSL sources and YAML template files.

        Glob results are sorted so the output does not depend on the
        filesystem's directory-listing order. When several GLSL files share a
        shader name, the last one collected wins.
        """
        for src_path in src_dir_paths:
            glsl_pattern = os.path.join(src_path, "**", "*.glsl*")
            for file in sorted(glob.glob(glsl_pattern, recursive=True)):
                self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
            yaml_pattern = os.path.join(src_path, "**", "*.yaml")
            self.template_yaml_files.extend(
                sorted(glob.glob(yaml_pattern, recursive=True))
            )

    def generateVariantCombinations(
        self,
        iterated_params: Dict[str, Any],
        exclude_params: Optional[Set[str]] = None,
    ) -> List[Any]:
        """Expand a ``generate_variant_forall`` spec into all combinations.

        Each entry maps a parameter name to a list of items of the form
        ``{"VALUE": v, "SUFFIX": s}`` (SUFFIX defaults to VALUE) or
        ``{"RANGE": [start, end], "SUFFIX": s}``. A RANGE item expands to every
        integer in the inclusive range, with suffix ``"<s>_<i>"``, or ``"<i>"``
        when no SUFFIX is given. Parameters in ``exclude_params`` are skipped.

        Returns:
            The Cartesian product over the non-excluded parameters, each
            element a tuple of ``(param_name, suffix, value)`` triples.

        Raises:
            ValueError: if a RANGE is not a two-element list.
            KeyError: if an item has neither VALUE nor RANGE.
        """
        if exclude_params is None:
            exclude_params = set()
        all_iterated_params = []
        for param_name, value_list in iterated_params.items():
            if param_name not in exclude_params:
                param_values = []
                for value in value_list:
                    if "RANGE" in value:
                        value_range = value["RANGE"]
                        suffix = value.get("SUFFIX", "")
                        if isinstance(value_range, list) and len(value_range) == 2:
                            for i in range(value_range[0], value_range[1] + 1):
                                curr_suffix = (
                                    suffix + "_" + str(i) if suffix else str(i)
                                )
                                param_values.append((param_name, curr_suffix, i))
                        else:
                            raise ValueError(
                                f"{value['RANGE']} is not a valid range. Must be in format [start, end] (inclusive)."
                            )

                    elif "VALUE" in value:
                        suffix = value.get("SUFFIX", value["VALUE"])
                        param_values.append((param_name, suffix, value["VALUE"]))

                    else:
                        raise KeyError(
                            "Parameter must be 'VALUE: string' or 'RANGE: [a, b]'"
                        )

                all_iterated_params.append(param_values)

        return list(product(*all_iterated_params))

    def parseTemplateYaml(self, yaml_file: str) -> None:
        """Record the parameter dicts of every shader variant in ``yaml_file``.

        Each top-level key is a template name whose value contains
        ``parameter_names_with_default_values`` and ``shader_variants``, and
        may contain a default ``generate_variant_forall``. Each variant starts
        from a deep copy of the defaults and is overridden by its own keys.
        When an iteration spec applies, the variant is multiplied out over it,
        with the iterated suffixes appended to ``NAME``.

        Raises:
            KeyError: if a template name was already defined.
            AssertionError: if a variant sets a parameter that has no default.
        """
        # NOTE(review): UniqueKeyLoader derives from PyYAML's Loader/CLoader,
        # which can construct arbitrary Python objects. That is only acceptable
        # because these YAML files are trusted in-tree build inputs.
        with open(yaml_file, encoding="utf-8") as f:
            contents = yaml.load(f, Loader=UniqueKeyLoader)
        if not contents:
            # An empty YAML file loads as None and defines no templates.
            return

        for template_name, params_dict in contents.items():
            if template_name in self.shader_template_params:
                raise KeyError(f"{template_name} params file is defined twice")

            default_params = params_dict["parameter_names_with_default_values"]
            params_names = set(default_params.keys()).union({"NAME"})

            self.shader_template_params[template_name] = []

            default_iterated_params = params_dict.get("generate_variant_forall", None)

            for variant in params_dict["shader_variants"]:
                variant_params_names = set(variant.keys())
                invalid_keys = (
                    variant_params_names - params_names - {"generate_variant_forall"}
                )
                # Explicit raise (not assert) so the check survives `python -O`.
                if invalid_keys:
                    raise AssertionError(
                        f"{template_name} variant {variant.get('NAME')} sets "
                        f"parameters without default values: "
                        f"{sorted(str(k) for k in invalid_keys)}"
                    )

                iterated_params = variant.get(
                    "generate_variant_forall", default_iterated_params
                )
                overrides = {
                    key: value
                    for key, value in variant.items()
                    if key != "generate_variant_forall"
                }

                if iterated_params is None:
                    variant_params = copy.deepcopy(default_params)
                    variant_params.update(overrides)
                    self.shader_template_params[template_name].append(variant_params)
                    continue

                # Parameters set explicitly by the variant are not iterated.
                variant_combinations = self.generateVariantCombinations(
                    iterated_params, variant_params_names
                )
                for combination in variant_combinations:
                    variant_params = copy.deepcopy(default_params)
                    variant_params.update(overrides)

                    variant_name = variant["NAME"]
                    for param_name, suffix, value in combination:
                        variant_params[param_name] = value
                        if str(suffix):
                            variant_name = f"{variant_name}_{suffix}"

                    variant_params["NAME"] = variant_name
                    self.shader_template_params[template_name].append(variant_params)

    def create_shader_params(
        self, variant_params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return a deep copy of the base environment overlaid with ``variant_params``.

        The deep copy gives each shader its own mutable counters (e.g. list
        valued binding indices), so expanding one shader cannot affect another.
        """
        if variant_params is None:
            variant_params = {}
        shader_params = copy.deepcopy(self.env)
        shader_params.update(variant_params)
        return shader_params

    def constructOutputMap(self) -> None:
        """Map every output shader name to its GLSL source and environment.

        Templated shaders contribute one entry per variant, keyed by the
        variant's NAME. Untemplated GLSL files contribute one entry under
        their own name.

        Raises:
            KeyError: if a template has variants but no GLSL source file.
        """
        for shader_name, params in self.shader_template_params.items():
            if not params:
                continue
            source_glsl = self.glsl_src_files.get(shader_name)
            if source_glsl is None:
                raise KeyError(
                    f"{shader_name} has template parameters but no GLSL source file"
                )
            for variant in params:
                self.output_shader_map[variant["NAME"]] = (
                    source_glsl,
                    self.create_shader_params(variant),
                )

        for shader_name, source_glsl in self.glsl_src_files.items():
            if shader_name not in self.shader_template_params:
                self.output_shader_map[shader_name] = (
                    source_glsl,
                    self.create_shader_params(),
                )

    def maybe_replace_u16vecn(self, input_text: str) -> str:
        """
        There is a latency benefit to using u16vecn variables to store texture position
        variables instead of ivecn, likely due to reduced register pressure. However,
        SwiftShader does not support 16 bit integer types in shaders, so this is a crude
        way to fallback to using ivecn to store texture positions so that testing with
        SwiftShader is still possible.

        Sources containing the marker "codegen-nosub" are never rewritten.
        """
        if not self.replace_u16vecn:
            return input_text
        if "codegen-nosub" in input_text:
            return input_text

        input_text = input_text.replace("u16vec", "ivec")
        input_text = input_text.replace("uint16_t", "int")
        return input_text

    def generateSPV(self, output_dir: str) -> Dict[str, str]:
        """Write every output shader's GLSL and, if glslc is set, compile it.

        Shaders are processed in parallel on a thread pool.

        Returns:
            Map from each compiled ``.spv`` path to its generated ``.glsl``
            path. It is empty when ``glslc_path`` is None, because nothing is
            compiled in that case.

        Raises:
            subprocess.CalledProcessError: if glslc fails on any shader.
        """
        output_file_map: Dict[str, str] = {}

        def process_shader(shader_paths_pair) -> Optional[Tuple[str, str]]:
            shader_name = shader_paths_pair[0]

            source_glsl = shader_paths_pair[1][0]
            shader_params = shader_paths_pair[1][1]

            with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
                input_text = input_file.read()
                input_text = self.maybe_replace_u16vecn(input_text)
                output_text = preprocess(input_text, shader_params)

            glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
            with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file:
                output_file.write(output_text)

            # If no GLSL compiler is specified, then only write out the generated
            # GLSL shaders. This is mainly for testing purposes.
            if self.glslc_path is None:
                return None

            spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

            cmd = (
                [
                    self.glslc_path,
                    "-fshader-stage=compute",
                    glsl_out_path,
                    "-o",
                    spv_out_path,
                    "--target-env=vulkan1.1",
                    "-Werror",
                ]
                + [
                    arg
                    for src_dir_path in self.src_dir_paths
                    for arg in ["-I", src_dir_path]
                ]
                + self.glslc_flags.split()
            )

            subprocess.check_call(cmd)

            return (spv_out_path, glsl_out_path)

        # Parallelize shader compilation as much as possible to optimize build time.
        with ThreadPool(os.cpu_count()) as pool:
            for result in pool.map(process_shader, self.output_shader_map.items()):
                # process_shader returns None when nothing was compiled; the
                # old code unpacked it unconditionally and raised TypeError.
                if result is None:
                    continue
                spv_out_path, glsl_out_path = result
                output_file_map[spv_out_path] = glsl_out_path

        return output_file_map
780
781
782##############################################
783#  Shader Info and Shader Registry Handling  #
784##############################################
785
786
@dataclass
class ShaderInfo:
    """Metadata scraped from a shader source file by getShaderInfo."""

    # Values from a " * TILE_SIZE = (x, y, z)" line; left empty if absent.
    tile_size: List[int]
    # One VK_DESCRIPTOR_TYPE_* name per "layout(set" line, in source order.
    layouts: List[str]
    # Value of a " * WEIGHT_STORAGE = ..." line, e.g. "TEXTURE_2D"; "" if absent.
    weight_storage_type: str = ""
    # Value of a " * BIAS_STORAGE = ..." line; "" if absent.
    bias_storage_type: str = ""
    # (shader name, registry keys) from a " * REGISTER_FOR = ..." line, if any.
    register_for: Optional[Tuple[str, List[str]]] = None
794
795
def getName(filePath: str) -> str:
    """Turn the final component of ``filePath`` into an identifier-like name.

    Every "/" and "." becomes "_" (e.g. "out/conv2d.spv" -> "conv2d_spv").
    """
    return os.path.basename(filePath).translate(str.maketrans({"/": "_", ".": "_"}))
798
799
def isDescriptorLine(lineStr: str) -> bool:
    """Return True if the line begins a ``layout(set ...)`` descriptor declaration."""
    return bool(re.search(r"^layout\(set", lineStr))
803
804
def isTileSizeLine(lineStr: str) -> bool:
    """Return True if the line is a `` * TILE_SIZE = (`` header annotation."""
    return bool(re.search(r"^ \* TILE_SIZE = \(", lineStr))
808
809
def findTileSizes(lineStr: str) -> List[int]:
    """Parse the three integers of a `` * TILE_SIZE = (x, y, z)`` line.

    Raises:
        AssertionError: if the line does not have that exact form.
    """
    found = re.search(r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)", lineStr)
    if not found:
        raise AssertionError("matches is None in findTileSizes")
    return [int(component) for component in found.groups()]
816
817
def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line is a `` * WEIGHT_STORAGE = `` header annotation."""
    return bool(re.search(r"^ \* WEIGHT_STORAGE = ", lineStr))
821
822
def getWeightStorageType(lineStr: str) -> str:
    """Return the storage value (e.g. ``TEXTURE_2D``) of a WEIGHT_STORAGE line.

    Raises:
        AssertionError: if no ``<letters>_<digit>D`` value follows the marker.
    """
    found = re.search(r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)", lineStr)
    if not found:
        raise AssertionError("matches is None in getWeightStorageType")
    (storage,) = found.groups()
    return storage
829
830
def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line is a `` * BIAS_STORAGE = `` header annotation."""
    return bool(re.search(r"^ \* BIAS_STORAGE = ", lineStr))
834
835
def getBiasStorageType(lineStr: str) -> str:
    """Return the storage value (e.g. ``TEXTURE_2D``) of a BIAS_STORAGE line.

    Raises:
        AssertionError: if no ``<letters>_<digit>D`` value follows the marker.
    """
    found = re.search(r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)", lineStr)
    if not found:
        raise AssertionError("matches is None in getBiasStorageType")
    (storage,) = found.groups()
    return storage
842
843
def isRegisterForLine(lineStr: str) -> bool:
    """Return True for a REGISTER_FOR line naming a shader and at least one key.

    The expected form is `` * REGISTER_FOR = ('shader', ['key', ...])``.
    """
    return bool(
        re.search(
            r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)",
            lineStr,
        )
    )
850
851
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
    """Extract ``(shader_name, registry_keys)`` from a REGISTER_FOR line.

    Every single-quoted identifier on the line is collected. The first is the
    shader name and the rest are the registry keys.

    Raises:
        AssertionError: if the line contains no quoted identifier.
    """
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    # re.findall returns a (possibly empty) list, never None. The old None
    # check could not fire, and an empty result crashed with IndexError.
    matches = re.findall(register_for_pattern, lineStr)
    if not matches:
        raise AssertionError("matches is empty in findRegisterFor")
    return (matches[0], matches[1:])
859
860
# Regex -> Vulkan descriptor type, tried in order by determineDescriptorType
# (the first match wins). Order matters: image and sampler declarations also
# contain the `uniform` keyword, so they must be tested before it.
typeIdMapping = dict(
    [
        (r"image[123]D\b", "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE"),
        (r"sampler[123]D\b", "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER"),
        (r"\bbuffer\b", "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"),
        (r"\buniform\b", "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER"),
    ]
)
867
868
def determineDescriptorType(lineStr: str) -> str:
    """Return the descriptor type of the first ``typeIdMapping`` regex matching the line.

    Raises:
        AssertionError: if no regex matches.
    """
    descriptor = next(
        (
            type_name
            for pattern, type_name in typeIdMapping.items()
            if re.search(pattern, lineStr)
        ),
        None,
    )
    if descriptor is not None:
        return descriptor
    raise AssertionError(
        "No matching descriptor type for " + lineStr + " in determineDescriptorType"
    )
876
877
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Scan a shader source file for descriptor layouts and header annotations.

    Every line is tested against each pattern independently. Descriptor lines
    are accumulated in order, and for the single-valued annotations (tile
    size, weight/bias storage, REGISTER_FOR) a later line overrides an earlier
    one.

    Args:
        srcFilePath: path of the GLSL file to scan.

    Returns:
        The collected ShaderInfo. Fields keep their defaults when the
        corresponding annotation is absent.
    """
    shader_info = ShaderInfo([], [], "")
    # GLSL is read and written as UTF-8 elsewhere in this file; do the same
    # here instead of depending on the platform's locale encoding.
    with open(srcFilePath, encoding="utf-8") as srcFile:
        for line in srcFile:
            if isDescriptorLine(line):
                shader_info.layouts.append(determineDescriptorType(line))
            if isTileSizeLine(line):
                shader_info.tile_size = findTileSizes(line)
            if isWeightStorageTypeLine(line):
                shader_info.weight_storage_type = getWeightStorageType(line)
            if isBiasStorageTypeLine(line):
                shader_info.bias_storage_type = getBiasStorageType(line)
            if isRegisterForLine(line):
                shader_info.register_for = findRegisterFor(line)

    return shader_info
894
895
896##########################
897#  C++ File Generation  #
898#########################
899
# str.format template for the generated C++ source file (written by genCppFiles).
# Placeholders:
#   {spv_bin_arrays}        - one `const uint32_t <name>_bin[] = {...};` per
#                             shader, placed in an anonymous namespace.
#   {register_shader_infos} - api::shader_registry().register_shader(...) calls.
#   {shader_info_registry}  - api::shader_registry().register_op_dispatch(...)
#                             calls, for shaders that declare REGISTER_FOR.
# Literal C++ braces are doubled ({{ / }}) because the string goes through
# str.format. register_fn is passed to a static api::ShaderRegisterInit object;
# presumably that object invokes it during static initialization (TODO confirm
# against ShaderRegistry.h).
cpp_template = """
#include <executorch/backends/vulkan/runtime/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>

using namespace vkcompute;

namespace at {{
namespace native {{
namespace vulkan {{

namespace {{

{spv_bin_arrays}

}}

static void register_fn() {{

{register_shader_infos}

{shader_info_registry}

}}

static const api::ShaderRegisterInit register_shaders(&register_fn);

}}
}}
}}

"""
932
933
def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]:
    """Render a compiled SPIR-V binary as a C++ ``uint32_t`` array definition.

    Args:
        spvPath: Path to the compiled ``.spv`` file.
        name: Shader name; the array is emitted as ``<name>_bin``.

    Returns:
        ``(sizeBytes, spv_bin_str)``: the binary's size in bytes, and the C++
        source ``const uint32_t <name>_bin[] = {...};`` with one word per line.

    Raises:
        ValueError: if the file size is not a whole number of 32-bit words.
        RuntimeError: if the platform's ``array`` typecode ``"I"`` is not
            4 bytes wide.
    """
    with open(spvPath, "rb") as fr:
        data = fr.read()

    next_bin = array.array("I")
    # The generated C++ declares uint32_t and sizeBytes must count 4-byte words,
    # but array's "I" is only guaranteed to be at least 2 bytes wide. Fail loudly
    # rather than emit a mis-sized table.
    if next_bin.itemsize != 4:
        raise RuntimeError(
            f"array typecode 'I' is {next_bin.itemsize} bytes on this platform; "
            "expected 4"
        )
    if len(data) % 4 != 0:
        raise ValueError(
            f"{spvPath}: size {len(data)} bytes is not a multiple of 4; "
            "not a valid SPIR-V binary"
        )
    # Words are decoded in native byte order, matching the original behavior.
    next_bin.frombytes(data)
    sizeBytes = len(data)

    spv_bin_str = "const uint32_t {}_bin[] = {{\n{}\n}};".format(
        name,
        textwrap.indent(",\n".join(str(x) for x in next_bin), "  "),
    )
    return sizeBytes, spv_bin_str
944
945
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
    """Build the C++ ``register_shader(vkapi::ShaderInfo(...))`` statement.

    The ShaderInfo constructor receives, in order:
    * the quoted shader name;
    * the ``<name>_bin`` array;
    * its byte size;
    * a brace list of descriptor types;
    * a brace list of tile sizes (``{1, 1, 1}`` when none were declared).

    Returns:
        The statement indented by four spaces and terminated by a newline.
    """
    if len(shader_info.tile_size) == 0:
        tile_size = "{1, 1, 1}"
    else:
        tile_size = "{" + ", ".join(map(str, shader_info.tile_size)) + "}"

    layouts = "{" + ",\n ".join(shader_info.layouts) + "}"

    ctor_args = ",\n".join(
        [f'"{name}"', f"{name}_bin", str(sizeBytes), layouts, tile_size]
    )
    statement = (
        "api::shader_registry().register_shader(\n"
        "  vkapi::ShaderInfo(\n"
        + textwrap.indent(ctor_args, " " * 5)
        + "));\n"
    )
    return textwrap.indent(statement, " " * 4)
971
972
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
    """Build the C++ ``register_op_dispatch`` statements for one shader.

    One statement is emitted per registry key in ``shader_info.register_for``.
    Each registers shader ``name`` for the op under
    ``api::DispatchKey::<KEY upper-cased>``.

    Args:
        shader_info: Parsed shader metadata; ``register_for`` is ``None`` or an
            ``(op_name, registry_keys)`` tuple.
        name: Shader name to register.

    Returns:
        The statements, each indented by four spaces and joined with newlines,
        or ``""`` when the shader declares no REGISTER_FOR.
    """
    if shader_info.register_for is None:
        return ""

    (op_name, registry_keys) = shader_info.register_for
    # Previously the loop reassigned a single string, so only the LAST key was
    # registered; emit one line per key instead.
    dispatch_lines = [
        f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");'
        for registry_key in registry_keys
    ]
    return textwrap.indent("\n".join(dispatch_lines), "    ")
986
987
def genCppFiles(
    spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
    """Write the C++ source that embeds and registers every compiled shader.

    For each ``(spv path, source path)`` pair this:
    * embeds the SPIR-V words as a ``<name>_bin`` array;
    * emits a ``register_shader`` call using metadata scanned from the source
      file;
    * emits ``register_op_dispatch`` calls when the source declares REGISTER_FOR.

    The shader name is ``getName(spv path)`` with ``"_spv"`` removed.

    Args:
        spv_files: Mapping of compiled ``.spv`` path to the source file it came
            from.
        cpp_header_path: Accepted but not used here; no header file is written.
        cpp_src_file_path: Destination of the generated C++ source.
    """
    bin_arrays: List[str] = []
    info_registrations: List[str] = []
    dispatch_registrations: List[str] = []

    for spv_path, glsl_path in spv_files.items():
        shader_name = getName(spv_path).replace("_spv", "")

        size_bytes, bin_array = generateSpvBinStr(spv_path, shader_name)
        bin_arrays.append(bin_array)

        info = getShaderInfo(glsl_path)
        info_registrations.append(
            generateShaderInfoStr(info, shader_name, size_bytes)
        )
        # Shaders without REGISTER_FOR contribute nothing (not even a blank line).
        if info.register_for is not None:
            dispatch_registrations.append(
                generateShaderDispatchStr(info, shader_name)
            )

    source_text = cpp_template.format(
        spv_bin_arrays="\n".join(bin_arrays),
        register_shader_infos="\n".join(info_registrations),
        shader_info_registry="\n".join(dispatch_registrations),
    )

    with open(cpp_src_file_path, "w") as out_file:
        out_file.write(source_text)
1022
1023
1024##########
1025#  Main  #
1026##########
1027
1028
def parse_arg_env(items: Optional[List[str]]) -> Dict[str, str]:
    """Parse ``KEY=VALUE`` strings (from ``--env``) into a dict.

    Whitespace around keys and values is stripped. Only the first ``=``
    separates key from value, so values may themselves contain ``=``
    (e.g. ``FLAGS=-DX=1``). When a key repeats, the later item wins.

    Args:
        items: Raw ``--env`` values, or ``None`` when the flag was not given.

    Returns:
        Mapping of key to value; empty when ``items`` is ``None`` or empty.

    Raises:
        ValueError: if an item contains no ``=``.
    """
    d: Dict[str, str] = {}
    for item in items or []:
        # partition splits on the first "=" only; split("=") would truncate
        # values containing "=" and raise IndexError when "=" is missing.
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"--env expects KEY=VALUE, got {item!r}")
        d[key.strip()] = value.strip()
    return d
1038
1039
def main(argv: List[str]) -> int:
    """Command-line entry point: generate SPIR-V for all shaders and emit the C++ registry.

    Args:
        argv: Accepted for the ``invoke_main`` call signature but not consulted.
            ``parser.parse_args()`` reads ``sys.argv[1:]`` itself.
            NOTE(review): consider ``parser.parse_args(argv[1:])``.

    Returns:
        0 on success.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-i",
        "--glsl-paths",
        nargs="+",
        help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
        default=["."],
    )
    parser.add_argument("-c", "--glslc-path", required=True, help="")
    parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
    parser.add_argument("-o", "--output-path", required=True, help="")
    parser.add_argument("--replace-u16vecn", action="store_true", default=False)
    parser.add_argument("--optimize_size", action="store_true", help="")
    parser.add_argument("--optimize", action="store_true", help="")
    parser.add_argument(
        "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
    )
    options = parser.parse_args()

    # NOTE: env aliases the module-level DEFAULT_ENV, so the updates below
    # also mutate DEFAULT_ENV itself.
    env = DEFAULT_ENV
    env.update(TYPE_MAPPINGS)
    env.update(UTILITY_FNS)

    # --env KEY=VALUE pairs override the defaults and mappings above.
    for key, value in parse_arg_env(options.env).items():
        env[key] = value

    # exist_ok=True replaces the exists()/makedirs() pair. That pair raced
    # (FileExistsError) when concurrent build steps shared a directory.
    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    # --optimize_size takes precedence over --optimize.
    glslc_flags = ""
    if options.optimize_size:
        glslc_flags += "-Os"
    elif options.optimize:
        glslc_flags += "-O"

    shader_generator = SPVGenerator(
        options.glsl_paths,
        env,
        options.glslc_path,
        glslc_flags=glslc_flags,
        replace_u16vecn=options.replace_u16vecn,
    )
    output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

    genCppFiles(
        output_spv_files,
        f"{options.output_path}/{CPP_H_NAME}",
        f"{options.output_path}/{CPP_SRC_NAME}",
    )

    return 0
1095
1096
def invoke_main() -> None:
    """Run ``main`` with the process arguments and exit using its return code."""
    raise SystemExit(main(sys.argv))
1099
1100
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    invoke_main()  # pragma: no cover
1103