xref: /aosp_15_r20/external/pytorch/tools/gen_vulkan_spv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport argparse
6*da0073e9SAndroid Build Coastguard Workerimport array
7*da0073e9SAndroid Build Coastguard Workerimport codecs
8*da0073e9SAndroid Build Coastguard Workerimport copy
9*da0073e9SAndroid Build Coastguard Workerimport glob
10*da0073e9SAndroid Build Coastguard Workerimport io
11*da0073e9SAndroid Build Coastguard Workerimport os
12*da0073e9SAndroid Build Coastguard Workerimport re
13*da0073e9SAndroid Build Coastguard Workerimport sys
14*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workersys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
17*da0073e9SAndroid Build Coastguard Workerimport subprocess
18*da0073e9SAndroid Build Coastguard Workerimport textwrap
19*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
20*da0073e9SAndroid Build Coastguard Workerfrom typing import Any
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Workerimport yaml
23*da0073e9SAndroid Build Coastguard Workerfrom yaml.constructor import ConstructorError
24*da0073e9SAndroid Build Coastguard Workerfrom yaml.nodes import MappingNode
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Workertry:
27*da0073e9SAndroid Build Coastguard Worker    from yaml import CLoader as Loader
28*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
29*da0073e9SAndroid Build Coastguard Worker    from yaml import Loader  # type: ignore[assignment, misc]
30*da0073e9SAndroid Build Coastguard Worker
# Names of the generated C++ header and source files.
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"

# Baseline substitution environment: shader precision plus the default image
# formats that create_shader_params() falls back to for float, int and uint
# dtypes.
DEFAULT_ENV: dict[str, Any] = dict(
    PRECISION="highp",
    FLOAT_IMAGE_FORMAT="rgba16f",
    INT_IMAGE_FORMAT="rgba32i",
    UINT_IMAGE_FORMAT="rgba32ui",
)
40*da0073e9SAndroid Build Coastguard Worker
# Per-dtype GLSL type/format lookup tables. Each top-level key is a table
# name; IMAGE_FORMAT, VEC4_T and T are keyed by dtype name, while IMAGE_T and
# SAMPLER_T are keyed first by the number of dimensions (3 or 2) and then by
# dtype name.
# NOTE(review): how templates index these tables is not visible in this part
# of the file (presumably something like ``${IMAGE_T[NDIM][DTYPE]}``) —
# confirm against the code that merges this into the template environment.
# Note that "half" uses the same GLSL types as "float" everywhere except
# IMAGE_FORMAT (rgba16f vs rgba32f).
TYPES_ENV: dict[str, Any] = {
    # Image format layout qualifier for each dtype.
    "IMAGE_FORMAT": {
        "float": "rgba32f",
        "half": "rgba16f",
        "int": "rgba32i",
        "uint": "rgba32ui",
        "int8": "rgba8i",
        "uint8": "rgba8ui",
    },
    # Storage image type, by dimensionality then dtype.
    # NOTE(review): no "int8"/"uint8" entries here, unlike IMAGE_FORMAT and
    # VEC4_T — a lookup for those dtypes would raise KeyError.
    "IMAGE_T": {
        3: {
            "float": "image3D",
            "half": "image3D",
            "int": "iimage3D",
            "uint": "uimage3D",
        },
        2: {
            "float": "image2D",
            "half": "image2D",
            "int": "iimage2D",
            "uint": "uimage2D",
        },
    },
    # Sampler type, by dimensionality then dtype (same dtype coverage as
    # IMAGE_T).
    "SAMPLER_T": {
        3: {
            "float": "sampler3D",
            "half": "sampler3D",
            "int": "isampler3D",
            "uint": "usampler3D",
        },
        2: {
            "float": "sampler2D",
            "half": "sampler2D",
            "int": "isampler2D",
            "uint": "usampler2D",
        },
    },
    # Four-component vector type for each dtype.
    "VEC4_T": {
        "float": "vec4",
        "half": "vec4",
        "int": "ivec4",
        "uint": "uvec4",
        # NOTE(review): "int8" maps to the float vector "vec4" while "uint8"
        # maps to the unsigned "uvec4" — confirm this asymmetry is intended.
        "int8": "vec4",
        "uint8": "uvec4",
    },
    # Scalar type for each dtype.
    "T": {
        "float": "float",
        "half": "float",
        "int": "int",
        "uint": "uint",
        "int8": "int",
        # NOTE(review): "uint8" is not a core GLSL scalar type name (core has
        # "uint"; the int8 extension spells it "uint8_t") — confirm.
        "uint8": "uint8",
    },
}
95*da0073e9SAndroid Build Coastguard Worker
# Helper callables keyed by name, then by number of dimensions. Presumably
# merged into the template environment by code outside this section — confirm.
# GET_POS[ndim](expr) returns ``expr`` unchanged for 3 dimensions and the text
# ``<expr>.xy`` for 2 dimensions.
FUNCS_ENV: dict[str, Any] = {
    "GET_POS": {
        3: lambda expr: expr,
        2: lambda expr: f"{expr}.xy",
    },
}
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
def extract_filename(path: str, keep_ext: bool = True) -> Any:
    """Return the final component of ``path``.

    With ``keep_ext=False`` everything from the first ``.`` onward is
    dropped, so ``"dir/conv.glsl.in"`` yields ``"conv"``.
    """
    base_name = os.path.basename(path)
    if not keep_ext:
        # partition() keeps the text before the first "." (or the whole name
        # when there is no dot).
        base_name = base_name.partition(".")[0]
    return base_name
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker############################
112*da0073e9SAndroid Build Coastguard Worker#  SPIR-V Code Generation  #
113*da0073e9SAndroid Build Coastguard Worker############################
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker
# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
    """YAML loader whose mappings reject repeated keys.

    Overrides ``construct_mapping`` to raise ``ConstructorError`` when a key
    appears twice in the same mapping. Presumably this exists so that a
    repeated key in a shader template YAML file is reported instead of
    silently overriding an earlier entry.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build a ``dict`` from a mapping node, rejecting bad keys.

        Args:
            node: The YAML node to convert; must be a ``MappingNode``.
            deep: Forwarded to ``construct_object`` for keys and values.

        Returns:
            A plain ``dict`` of the constructed keys and values.

        Raises:
            ConstructorError: If ``node`` is not a mapping, a key is
                unhashable, or a key appears more than once.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            try:
                hash(key)
            except TypeError as e:
                # Include the reason (e.g. "unhashable type: 'list'") so the
                # message is actionable.
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found unacceptable key ({e})",
                    key_node.start_mark,
                ) from e
            # Check for duplicate keys and name the offending key.
            if key in mapping:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found duplicate key {key!r}",
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
            mapping[key] = value
        return mapping
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def extract_leading_whitespace(line: str) -> str:
    """Return the run of whitespace characters at the start of ``line``."""
    leading = re.match(r"\s*", line)
    if leading is None:
        return ""
    return leading.group(0)
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def escape(line: str) -> str:
    """Translate one literal template line into a Python string expression.

    Text outside ``${...}`` becomes a double-quoted Python string literal and
    each ``${expr}`` becomes ``str(expr)``; the pieces are joined with
    `` + ``. The result is intended to be compiled as a ``print`` argument.

    Raises:
        ValueError: If a ``${`` has no closing ``}`` later on the line.
    """

    def quote(text: str) -> str:
        # Escape backslashes before quotes. Otherwise a literal backslash
        # (e.g. a GLSL macro line continuation at end of line) would escape
        # the closing quote or be read as an escape sequence like "\n".
        return '"' + text.replace("\\", "\\\\").replace('"', '\\"') + '"'

    output_parts = []
    while "${" in line:
        start_pos = line.index("${")
        end_pos = line.index("}", start_pos + 2)
        if start_pos != 0:
            output_parts.append(quote(line[:start_pos]))
        output_parts.append("str(" + line[start_pos + 2 : end_pos] + ")")
        line = line[end_pos + 1 :]
    if line:
        output_parts.append(quote(line))
    # An empty line must still be a valid expression; "" would yield an
    # invalid ``print(, file=...)`` call in the generated code.
    return " + ".join(output_parts) if output_parts else '""'
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess(
    input_text: str, variables: dict[str, Any], input_path: str = "codegen"
) -> str:
    """Expand a ``$``-templated text by generating and executing Python.

    Template syntax, as implemented below:

    * A line whose first non-whitespace character is ``$`` (but which does
      not start with ``${``) is Python code: every ``$`` on it is removed and
      the rest is emitted as a Python statement. If it ends with ``:`` it
      opens a block, and the next line's extra indentation becomes the
      block's Python indentation.
    * Any other line is literal text, printed with ``${expr}`` placeholders
      replaced by ``str(expr)`` (via ``escape``).
    * Lines containing ``LINT`` are dropped; empty lines are reproduced.

    Args:
        input_text: The template source.
        variables: Globals visible to the template's Python code and
            ``${...}`` expressions. The mapping itself is not modified; a
            shallow copy serves as the exec globals.
        input_path: File name reported in tracebacks from the generated code.

    Returns:
        The expanded text.
    """
    input_lines = input_text.splitlines()
    python_lines = []

    # Empty input lines seen since the last non-empty line. They are emitted
    # lazily, as bare print() calls at the indentation of the next line.
    blank_lines = 0

    # Leading whitespace of the previous non-empty input line.
    last_indent = ""

    # Python indentation of the current line. Initialized here so an input
    # made only of empty lines (e.g. "\n\n") does not hit an unbound local in
    # the trailing flush after the loop.
    python_indent = ""

    # Stack of (input_indent, python_indent) tuples, one per open Python
    # block; the bottom entry is module level.
    indent_stack = [("", "")]

    # Whether the next non-empty line is the first line inside a Python code
    # block (i.e. after a for/while/if/elif/else line ending in ':'). It
    # starts False: treating the first template line as a block start would
    # turn its leading whitespace into module-level Python indentation, so a
    # template whose first line is indented would fail with "unexpected
    # indent".
    python_block_start = False
    for input_line in input_lines:
        if input_line == "":
            blank_lines += 1
            continue
        # Skip lint markers.
        if "LINT" in input_line:
            continue

        input_indent = extract_leading_whitespace(input_line)
        if python_block_start:
            # The block body's extra indentation relative to the line that
            # opened the block becomes that block's Python indentation.
            assert input_indent.startswith(last_indent)
            extra_python_indent = input_indent[len(last_indent) :]
            python_indent = indent_stack[-1][1] + extra_python_indent
            indent_stack.append((input_indent, python_indent))
        else:
            # Dedent: close every block this line is no longer inside.
            while not input_indent.startswith(indent_stack[-1][0]):
                del indent_stack[-1]
        python_block_start = False

        python_indent = indent_stack[-1][1]
        stripped_input_line = input_line.strip()
        if stripped_input_line.startswith("$") and not stripped_input_line.startswith(
            "${"
        ):
            # Python control line.
            if stripped_input_line.endswith(":"):
                python_block_start = True
            python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)
            blank_lines = 0
            python_lines.append(python_indent + stripped_input_line.replace("$", ""))
        else:
            # Literal text line: strip only the Python indentation; any
            # remaining leading whitespace is printed as part of the text.
            assert input_line.startswith(python_indent)
            python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)
            blank_lines = 0
            python_lines.append(
                python_indent
                + f"print({escape(input_line[len(python_indent) :])}, file=OUT_STREAM)"
            )
        last_indent = input_indent

    python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)

    exec_globals = dict(variables)
    output_stream = io.StringIO()
    exec_globals["OUT_STREAM"] = output_stream

    # NOTE: this executes Python taken from the template, so templates must be
    # trusted inputs.
    python_bytecode = compile("\n".join(python_lines), input_path, "exec")
    exec(python_bytecode, exec_globals)

    return output_stream.getvalue()
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Workerclass SPVGenerator:
246*da0073e9SAndroid Build Coastguard Worker    def __init__(
247*da0073e9SAndroid Build Coastguard Worker        self,
248*da0073e9SAndroid Build Coastguard Worker        src_dir_paths: str | list[str],
249*da0073e9SAndroid Build Coastguard Worker        env: dict[Any, Any],
250*da0073e9SAndroid Build Coastguard Worker        glslc_path: str | None,
251*da0073e9SAndroid Build Coastguard Worker    ) -> None:
252*da0073e9SAndroid Build Coastguard Worker        if isinstance(src_dir_paths, str):
253*da0073e9SAndroid Build Coastguard Worker            self.src_dir_paths = [src_dir_paths]
254*da0073e9SAndroid Build Coastguard Worker        else:
255*da0073e9SAndroid Build Coastguard Worker            self.src_dir_paths = src_dir_paths
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker        self.env = env
258*da0073e9SAndroid Build Coastguard Worker        self.glslc_path = glslc_path
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        self.glsl_src_files: dict[str, str] = {}
261*da0073e9SAndroid Build Coastguard Worker        self.template_yaml_files: list[str] = []
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker        self.addSrcAndYamlFiles(self.src_dir_paths)
264*da0073e9SAndroid Build Coastguard Worker        self.shader_template_params: dict[Any, Any] = {}
265*da0073e9SAndroid Build Coastguard Worker        for yaml_file in self.template_yaml_files:
266*da0073e9SAndroid Build Coastguard Worker            self.parseTemplateYaml(yaml_file)
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker        self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
269*da0073e9SAndroid Build Coastguard Worker        self.constructOutputMap()
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker    def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
272*da0073e9SAndroid Build Coastguard Worker        for src_path in src_dir_paths:
273*da0073e9SAndroid Build Coastguard Worker            # Collect glsl source files
274*da0073e9SAndroid Build Coastguard Worker            glsl_files = glob.glob(
275*da0073e9SAndroid Build Coastguard Worker                os.path.join(src_path, "**", "*.glsl*"), recursive=True
276*da0073e9SAndroid Build Coastguard Worker            )
277*da0073e9SAndroid Build Coastguard Worker            for file in glsl_files:
278*da0073e9SAndroid Build Coastguard Worker                if len(file) > 1:
279*da0073e9SAndroid Build Coastguard Worker                    self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
280*da0073e9SAndroid Build Coastguard Worker            # Collect template yaml files
281*da0073e9SAndroid Build Coastguard Worker            yaml_files = glob.glob(
282*da0073e9SAndroid Build Coastguard Worker                os.path.join(src_path, "**", "*.yaml"), recursive=True
283*da0073e9SAndroid Build Coastguard Worker            )
284*da0073e9SAndroid Build Coastguard Worker            for file in yaml_files:
285*da0073e9SAndroid Build Coastguard Worker                if len(file) > 1:
286*da0073e9SAndroid Build Coastguard Worker                    self.template_yaml_files.append(file)
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker    def generateVariantCombinations(
289*da0073e9SAndroid Build Coastguard Worker        self,
290*da0073e9SAndroid Build Coastguard Worker        iterated_params: dict[str, Any],
291*da0073e9SAndroid Build Coastguard Worker        exclude_params: set[str] | None = None,
292*da0073e9SAndroid Build Coastguard Worker    ) -> list[Any]:
293*da0073e9SAndroid Build Coastguard Worker        if exclude_params is None:
294*da0073e9SAndroid Build Coastguard Worker            exclude_params = set()
295*da0073e9SAndroid Build Coastguard Worker        all_iterated_params = []
296*da0073e9SAndroid Build Coastguard Worker        for param_name, value_list in iterated_params.items():
297*da0073e9SAndroid Build Coastguard Worker            if param_name not in exclude_params:
298*da0073e9SAndroid Build Coastguard Worker                param_values = []
299*da0073e9SAndroid Build Coastguard Worker                for value in value_list:
300*da0073e9SAndroid Build Coastguard Worker                    suffix = value.get("SUFFIX", value["VALUE"])
301*da0073e9SAndroid Build Coastguard Worker                    param_values.append((param_name, suffix, value["VALUE"]))
302*da0073e9SAndroid Build Coastguard Worker                all_iterated_params.append(param_values)
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        return list(product(*all_iterated_params))
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker    def parseTemplateYaml(self, yaml_file: str) -> None:
307*da0073e9SAndroid Build Coastguard Worker        with open(yaml_file) as f:
308*da0073e9SAndroid Build Coastguard Worker            contents = yaml.load(f, Loader=UniqueKeyLoader)
309*da0073e9SAndroid Build Coastguard Worker            for template_name, params_dict in contents.items():
310*da0073e9SAndroid Build Coastguard Worker                if template_name in self.shader_template_params:
311*da0073e9SAndroid Build Coastguard Worker                    raise KeyError(f"{template_name} params file is defined twice")
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker                default_params = params_dict["parameter_names_with_default_values"]
314*da0073e9SAndroid Build Coastguard Worker                params_names = set(default_params.keys()).union({"NAME"})
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker                self.shader_template_params[template_name] = []
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker                default_iterated_params = params_dict.get(
319*da0073e9SAndroid Build Coastguard Worker                    "generate_variant_forall", None
320*da0073e9SAndroid Build Coastguard Worker                )
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker                for variant in params_dict["shader_variants"]:
323*da0073e9SAndroid Build Coastguard Worker                    variant_params_names = set(variant.keys())
324*da0073e9SAndroid Build Coastguard Worker                    invalid_keys = (
325*da0073e9SAndroid Build Coastguard Worker                        variant_params_names
326*da0073e9SAndroid Build Coastguard Worker                        - params_names
327*da0073e9SAndroid Build Coastguard Worker                        - {"generate_variant_forall"}
328*da0073e9SAndroid Build Coastguard Worker                    )
329*da0073e9SAndroid Build Coastguard Worker                    assert len(invalid_keys) == 0
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker                    iterated_params = variant.get(
332*da0073e9SAndroid Build Coastguard Worker                        "generate_variant_forall", default_iterated_params
333*da0073e9SAndroid Build Coastguard Worker                    )
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker                    if iterated_params is not None:
336*da0073e9SAndroid Build Coastguard Worker                        variant_combinations = self.generateVariantCombinations(
337*da0073e9SAndroid Build Coastguard Worker                            iterated_params, variant_params_names
338*da0073e9SAndroid Build Coastguard Worker                        )
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker                        for combination in variant_combinations:
341*da0073e9SAndroid Build Coastguard Worker                            default_params_copy = copy.deepcopy(default_params)
342*da0073e9SAndroid Build Coastguard Worker                            for key in variant:
343*da0073e9SAndroid Build Coastguard Worker                                if key != "generate_variant_forall":
344*da0073e9SAndroid Build Coastguard Worker                                    default_params_copy[key] = variant[key]
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker                            variant_name = variant["NAME"]
347*da0073e9SAndroid Build Coastguard Worker                            for param_value in combination:
348*da0073e9SAndroid Build Coastguard Worker                                default_params_copy[param_value[0]] = param_value[2]
349*da0073e9SAndroid Build Coastguard Worker                                if len(param_value[1]) > 0:
350*da0073e9SAndroid Build Coastguard Worker                                    variant_name = f"{variant_name}_{param_value[1]}"
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker                            default_params_copy["NAME"] = variant_name
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker                            self.shader_template_params[template_name].append(
355*da0073e9SAndroid Build Coastguard Worker                                default_params_copy
356*da0073e9SAndroid Build Coastguard Worker                            )
357*da0073e9SAndroid Build Coastguard Worker                    else:
358*da0073e9SAndroid Build Coastguard Worker                        default_params_copy = copy.deepcopy(default_params)
359*da0073e9SAndroid Build Coastguard Worker                        for key in variant:
360*da0073e9SAndroid Build Coastguard Worker                            default_params_copy[key] = variant[key]
361*da0073e9SAndroid Build Coastguard Worker
362*da0073e9SAndroid Build Coastguard Worker                        self.shader_template_params[template_name].append(
363*da0073e9SAndroid Build Coastguard Worker                            default_params_copy
364*da0073e9SAndroid Build Coastguard Worker                        )
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker    def create_shader_params(
367*da0073e9SAndroid Build Coastguard Worker        self, variant_params: dict[str, Any] | None = None
368*da0073e9SAndroid Build Coastguard Worker    ) -> dict[str, str]:
369*da0073e9SAndroid Build Coastguard Worker        if variant_params is None:
370*da0073e9SAndroid Build Coastguard Worker            variant_params = {}
371*da0073e9SAndroid Build Coastguard Worker        shader_params = copy.deepcopy(self.env)
372*da0073e9SAndroid Build Coastguard Worker        for key, value in variant_params.items():
373*da0073e9SAndroid Build Coastguard Worker            shader_params[key] = value
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker        shader_dtype = shader_params.get("DTYPE", "float")
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker        if shader_dtype == "int":
378*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
379*da0073e9SAndroid Build Coastguard Worker        elif shader_dtype == "uint":
380*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
381*da0073e9SAndroid Build Coastguard Worker        elif shader_dtype == "int32":
382*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = "rgba32i"
383*da0073e9SAndroid Build Coastguard Worker        elif shader_dtype == "uint32":
384*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = "rgba32ui"
385*da0073e9SAndroid Build Coastguard Worker        elif shader_dtype == "int8":
386*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = "rgba8i"
387*da0073e9SAndroid Build Coastguard Worker        elif shader_dtype == "uint8":
388*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = "rgba8ui"
389*da0073e9SAndroid Build Coastguard Worker        elif shader_dtype == "float32":
390*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = "rgba32f"
391*da0073e9SAndroid Build Coastguard Worker        # Assume float by default
392*da0073e9SAndroid Build Coastguard Worker        else:
393*da0073e9SAndroid Build Coastguard Worker            shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]
394*da0073e9SAndroid Build Coastguard Worker
395*da0073e9SAndroid Build Coastguard Worker        return shader_params
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker    def constructOutputMap(self) -> None:
398*da0073e9SAndroid Build Coastguard Worker        for shader_name, params in self.shader_template_params.items():
399*da0073e9SAndroid Build Coastguard Worker            for variant in params:
400*da0073e9SAndroid Build Coastguard Worker                source_glsl = self.glsl_src_files[shader_name]
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker                self.output_shader_map[variant["NAME"]] = (
403*da0073e9SAndroid Build Coastguard Worker                    source_glsl,
404*da0073e9SAndroid Build Coastguard Worker                    self.create_shader_params(variant),
405*da0073e9SAndroid Build Coastguard Worker                )
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker        for shader_name, source_glsl in self.glsl_src_files.items():
408*da0073e9SAndroid Build Coastguard Worker            if shader_name not in self.shader_template_params:
409*da0073e9SAndroid Build Coastguard Worker                self.output_shader_map[shader_name] = (
410*da0073e9SAndroid Build Coastguard Worker                    source_glsl,
411*da0073e9SAndroid Build Coastguard Worker                    self.create_shader_params(),
412*da0073e9SAndroid Build Coastguard Worker                )
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker    def generateSPV(self, output_dir: str) -> dict[str, str]:
415*da0073e9SAndroid Build Coastguard Worker        output_file_map = {}
416*da0073e9SAndroid Build Coastguard Worker        for shader_name in self.output_shader_map:
417*da0073e9SAndroid Build Coastguard Worker            source_glsl = self.output_shader_map[shader_name][0]
418*da0073e9SAndroid Build Coastguard Worker            shader_params = self.output_shader_map[shader_name][1]
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker            with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
421*da0073e9SAndroid Build Coastguard Worker                input_text = input_file.read()
422*da0073e9SAndroid Build Coastguard Worker                output_text = preprocess(input_text, shader_params)
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker            glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
425*da0073e9SAndroid Build Coastguard Worker            with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file:
426*da0073e9SAndroid Build Coastguard Worker                output_file.write(output_text)
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker            # If no GLSL compiler is specified, then only write out the generated GLSL shaders.
429*da0073e9SAndroid Build Coastguard Worker            # This is mainly for testing purposes.
430*da0073e9SAndroid Build Coastguard Worker            if self.glslc_path is not None:
431*da0073e9SAndroid Build Coastguard Worker                spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker                cmd = [
434*da0073e9SAndroid Build Coastguard Worker                    self.glslc_path,
435*da0073e9SAndroid Build Coastguard Worker                    "-fshader-stage=compute",
436*da0073e9SAndroid Build Coastguard Worker                    glsl_out_path,
437*da0073e9SAndroid Build Coastguard Worker                    "-o",
438*da0073e9SAndroid Build Coastguard Worker                    spv_out_path,
439*da0073e9SAndroid Build Coastguard Worker                    "--target-env=vulkan1.0",
440*da0073e9SAndroid Build Coastguard Worker                    "-Werror",
441*da0073e9SAndroid Build Coastguard Worker                ] + [
442*da0073e9SAndroid Build Coastguard Worker                    arg
443*da0073e9SAndroid Build Coastguard Worker                    for src_dir_path in self.src_dir_paths
444*da0073e9SAndroid Build Coastguard Worker                    for arg in ["-I", src_dir_path]
445*da0073e9SAndroid Build Coastguard Worker                ]
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker                print("glslc cmd:", cmd)
448*da0073e9SAndroid Build Coastguard Worker                subprocess.check_call(cmd)
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker                output_file_map[spv_out_path] = glsl_out_path
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker        return output_file_map
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker##############################################
456*da0073e9SAndroid Build Coastguard Worker#  Shader Info and Shader Registry Handling  #
457*da0073e9SAndroid Build Coastguard Worker##############################################
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker
@dataclass
class ShaderInfo:
    """Per-shader metadata that getShaderInfo parses out of a GLSL file."""

    # Tile size from a " * TILE_SIZE = (x, y, z)" header; getShaderInfo leaves it [] when absent.
    tile_size: list[int]
    # One VK_DESCRIPTOR_TYPE_* string per "layout(set" line, in source order.
    layouts: list[str]
    # Value of the " * WEIGHT_STORAGE = " header (a key of storageTypeToEnum); "" when absent.
    weight_storage_type: str = ""
    # Value of the " * BIAS_STORAGE = " header (a key of storageTypeToEnum); "" when absent.
    bias_storage_type: str = ""
    # (op name, registry/dispatch keys) from a " * REGISTER_FOR = " header; None when absent.
    register_for: tuple[str, list[str]] | None = None
467*da0073e9SAndroid Build Coastguard Worker
468*da0073e9SAndroid Build Coastguard Worker
def getName(filePath: str) -> str:
    """Return the file's base name with every '/' and '.' turned into '_'.

    For example, ``"out/conv2d.spv"`` becomes ``"conv2d_spv"``.
    """
    base_name = os.path.basename(filePath)
    return base_name.translate(str.maketrans({"/": "_", ".": "_"}))
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker
def isDescriptorLine(lineStr: str) -> bool:
    """Tell whether *lineStr* begins with a ``layout(set`` descriptor declaration."""
    return bool(re.match(r"^layout\(set", lineStr))
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker
def isTileSizeLine(lineStr: str) -> bool:
    """Tell whether *lineStr* opens a `` * TILE_SIZE = (`` header comment line."""
    return bool(re.match(r"^ \* TILE_SIZE = \(", lineStr))
481*da0073e9SAndroid Build Coastguard Worker
482*da0073e9SAndroid Build Coastguard Worker
def findTileSizes(lineStr: str) -> list[int]:
    """Parse the three integers of a `` * TILE_SIZE = (x, y, z)`` header line.

    Raises:
        AssertionError: if the line does not have that exact shape.
    """
    found = re.match(r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)", lineStr)
    if not found:
        raise AssertionError("matches is None in findTileSizes")
    return [int(dim) for dim in found.groups()]
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker
def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Tell whether *lineStr* is a `` * WEIGHT_STORAGE = `` header comment line."""
    return bool(re.match(r"^ \* WEIGHT_STORAGE = ", lineStr))
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker
def getWeightStorageType(lineStr: str) -> str:
    """Extract the storage type from a `` * WEIGHT_STORAGE = ...`` header line.

    Accepts the texture forms (``TEXTURE_2D``, ``TEXTURE_3D``) and the plain
    ``BUFFER`` keyword, i.e. every non-empty key of ``storageTypeToEnum``.
    The previous pattern required an ``_<digit>D`` suffix and so rejected
    ``BUFFER``.

    Raises:
        AssertionError: if no storage-type token follows the marker.
    """
    # The "_<digit>D" suffix is optional so that BUFFER is accepted as well.
    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(weight_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return matches.group(1)
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker
def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Tell whether *lineStr* is a `` * BIAS_STORAGE = `` header comment line."""
    bias_storage_id = r"^ \* BIAS_STORAGE = "
    return bool(re.match(bias_storage_id, lineStr))
507*da0073e9SAndroid Build Coastguard Worker
508*da0073e9SAndroid Build Coastguard Worker
def getBiasStorageType(lineStr: str) -> str:
    """Extract the storage type from a `` * BIAS_STORAGE = ...`` header line.

    Accepts the texture forms (``TEXTURE_2D``, ``TEXTURE_3D``) and the plain
    ``BUFFER`` keyword, i.e. every non-empty key of ``storageTypeToEnum``.
    The previous pattern required an ``_<digit>D`` suffix and so rejected
    ``BUFFER``.

    Raises:
        AssertionError: if no storage-type token follows the marker.
    """
    # The "_<digit>D" suffix is optional so that BUFFER is accepted as well.
    bias_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(bias_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return matches.group(1)
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker
def isRegisterForLine(lineStr: str) -> bool:
    """Tell whether *lineStr* is a well-formed `` * REGISTER_FOR = `` header.

    The header must name a quoted shader/op and a bracketed list that starts
    with at least one quoted registry key, e.g.
    `` * REGISTER_FOR = ('conv2d', ['catchall'])``.
    """
    pattern = r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
    return bool(re.match(pattern, lineStr))
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Worker
def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
    """Extract the op name and registry keys from a `` * REGISTER_FOR = `` header.

    Every single-quoted identifier on the line is collected in order. The
    first is the op name and the rest are the registry (dispatch) keys, e.g.
    `` * REGISTER_FOR = ('conv2d', ['catchall', 'adreno'])`` yields
    ``("conv2d", ["catchall", "adreno"])``.

    Raises:
        AssertionError: if the line contains no single-quoted identifier.
    """
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    matches = re.findall(register_for_pattern, lineStr)
    # re.findall returns a (possibly empty) list and never None, so check for
    # emptiness; otherwise matches[0] would raise a bare IndexError.
    if not matches:
        raise AssertionError("no quoted identifiers found in findRegisterFor")
    return (matches[0], matches[1:])
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker
# Regex patterns (searched anywhere in a "layout(set" line) mapped to the Vulkan
# descriptor type they indicate. determineDescriptorType returns the value of the
# first pattern that matches, in insertion order, so entry order is significant.
typeIdMapping = {
    r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
    r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
    r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
    r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER",
}
540*da0073e9SAndroid Build Coastguard Worker
# Storage-type names from the WEIGHT_STORAGE / BIAS_STORAGE headers mapped to the
# C++ enum spelled into the generated ShaderInfo. The "" key covers shaders that
# declare no such header (ShaderInfo's default storage type is "").
storageTypeToEnum = {
    "TEXTURE_2D": "api::StorageType::TEXTURE_2D",
    "TEXTURE_3D": "api::StorageType::TEXTURE_3D",
    "BUFFER": "api::StorageType::BUFFER",
    "": "api::StorageType::UNKNOWN",
}
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker
def determineDescriptorType(lineStr: str) -> str:
    """Map a GLSL ``layout(set ...)`` declaration to its Vulkan descriptor type.

    Patterns in ``typeIdMapping`` are tried in insertion order and the first
    one found anywhere in *lineStr* decides the result.

    Raises:
        AssertionError: if none of the patterns matches.
    """
    descriptor_type = next(
        (
            vk_type
            for pattern, vk_type in typeIdMapping.items()
            if re.search(pattern, lineStr)
        ),
        None,
    )
    if descriptor_type is None:
        raise AssertionError(
            "No matching descriptor type for " + lineStr + " in determineDescriptorType"
        )
    return descriptor_type
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Collect registry metadata from the lines of a GLSL file.

    Each line is checked against every recognised marker:

    * ``layout(set ...`` lines append a descriptor type to ``layouts``;
    * `` * TILE_SIZE = (x, y, z)`` sets ``tile_size``;
    * `` * WEIGHT_STORAGE = ...`` / `` * BIAS_STORAGE = ...`` set the storage types;
    * `` * REGISTER_FOR = ('op', ['key', ...])`` sets ``register_for``.

    For the single-valued fields, a later matching line overwrites an earlier one.

    Args:
        srcFilePath: Path of the GLSL file to scan.

    Returns:
        A ``ShaderInfo``. Fields whose marker never appears keep their empty
        values (``[]``, ``""`` or ``None``).
    """
    shader_info = ShaderInfo([], [], "")
    # In this script the path comes from generateSPV, which writes the GLSL as
    # UTF-8. Read it back explicitly as UTF-8 rather than in the locale's
    # default encoding.
    with open(srcFilePath, encoding="utf-8") as srcFile:
        for line in srcFile:
            if isDescriptorLine(line):
                shader_info.layouts.append(determineDescriptorType(line))
            if isTileSizeLine(line):
                shader_info.tile_size = findTileSizes(line)
            if isWeightStorageTypeLine(line):
                shader_info.weight_storage_type = getWeightStorageType(line)
            if isBiasStorageTypeLine(line):
                shader_info.bias_storage_type = getBiasStorageType(line)
            if isRegisterForLine(line):
                shader_info.register_for = findRegisterFor(line)

    return shader_info
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker##########################
577*da0073e9SAndroid Build Coastguard Worker#  C++ File Generation  #
578*da0073e9SAndroid Build Coastguard Worker#########################
579*da0073e9SAndroid Build Coastguard Worker
# Skeleton of the generated C++ source. genCppFiles fills it with str.format, so
# literal C++ braces are doubled ("{{" / "}}"). The placeholders are
# {spv_bin_arrays}, {register_shader_infos} and {shader_info_registry}.
cpp_template = """
#include <ATen/native/vulkan/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>

using namespace at::native::vulkan;

namespace at {{
namespace native {{
namespace vulkan {{

namespace {{

{spv_bin_arrays}

}}

static void register_fn() {{

{register_shader_infos}

{shader_info_registry}

}}

static const api::ShaderRegisterInit register_shaders(&register_fn);

}}
}}
}}

"""
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker
def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
    """Render a SPIR-V binary as a C++ ``uint32_t`` array definition.

    The file is read as native-endian unsigned ints (``array`` typecode
    ``"I"``). Each word is emitted on its own line, indented by two spaces.

    Args:
        spvPath: Path to the ``.spv`` file.
        name: Prefix for the generated ``<name>_bin`` array.

    Returns:
        ``(size_in_bytes, cpp_definition)``, where the size is computed as
        4 bytes per word.
    """
    with open(spvPath, "rb") as spv_file:
        words = array.array("I", spv_file.read())
        word_lines = textwrap.indent(",\n".join(map(str, words)), "  ")
        definition = f"const uint32_t {name}_bin[] = {{\n{word_lines}\n}};"

    return len(words) * 4, definition
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
    """Render the ``register_shader(ShaderInfo(...))`` statement for one shader.

    The ``ShaderInfo`` constructor receives, one argument per line:

    * the quoted shader name;
    * the ``<name>_bin`` array and its size in bytes;
    * the brace-enclosed descriptor layout list;
    * the tile size, as a brace list or ``std::vector<uint32_t>()`` when empty;
    * the weight and bias storage enums from ``storageTypeToEnum``.

    Returns:
        The C++ statement, indented by four spaces and ending in a newline.

    Raises:
        KeyError: if a storage type is not a key of ``storageTypeToEnum``.
    """
    if shader_info.tile_size:
        tile_size_expr = "{" + ", ".join(map(str, shader_info.tile_size)) + "}"
    else:
        tile_size_expr = "std::vector<uint32_t>()"

    layouts_expr = "{" + ",\n ".join(shader_info.layouts) + "}"

    ctor_args = ",\n".join(
        [
            f'"{name}"',
            f"{name}_bin",
            str(sizeBytes),
            layouts_expr,
            tile_size_expr,
            storageTypeToEnum[shader_info.weight_storage_type],
            storageTypeToEnum[shader_info.bias_storage_type],
        ]
    )
    indented_args = textwrap.indent(ctor_args, "     ")

    register_call = (
        "api::shader_registry().register_shader(\n"
        "  api::ShaderInfo(\n" + indented_args + "));\n"
    )
    return textwrap.indent(register_call, "    ")
653*da0073e9SAndroid Build Coastguard Worker
654*da0073e9SAndroid Build Coastguard Worker
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
    """Render the ``register_op_dispatch`` calls declared by ``REGISTER_FOR``.

    One call is emitted per registry key, each indented four spaces and joined
    by newlines. Previously the loop overwrote its result on every iteration,
    so only the last key was registered.

    Args:
        shader_info: Parsed metadata; only ``register_for`` is read.
        name: Shader name passed as the dispatch target.

    Returns:
        The C++ statements, or ``""`` when ``register_for`` is ``None`` or
        lists no keys.
    """
    if shader_info.register_for is None:
        return ""

    op_name, registry_keys = shader_info.register_for
    dispatch_calls = [
        f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");'
        for registry_key in registry_keys
    ]
    return textwrap.indent("\n".join(dispatch_calls), "    ")
668*da0073e9SAndroid Build Coastguard Worker
def genCppFiles(
    spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
    """Write the C++ source that embeds every SPIR-V binary and registers it.

    For each ``spv_path -> glsl_path`` pair, the function:

    * derives the shader name from the ``.spv`` file name;
    * embeds the binary as a ``uint32_t`` array;
    * emits a ``register_shader`` call from the metadata in the GLSL file;
    * emits dispatch registrations when the GLSL declares ``REGISTER_FOR``.

    The pieces are spliced into ``cpp_template`` and written to
    *cpp_src_file_path*. *cpp_header_path* is accepted but not used here.
    """
    bin_arrays = []
    info_registrations = []
    dispatch_registrations = []

    for spv_path, glsl_path in spv_files.items():
        shader_name = getName(spv_path).replace("_spv", "")

        size_bytes, bin_array = generateSpvBinStr(spv_path, shader_name)
        bin_arrays.append(bin_array)

        info = getShaderInfo(glsl_path)
        info_registrations.append(generateShaderInfoStr(info, shader_name, size_bytes))

        # Skip shaders without REGISTER_FOR so they add no blank entries.
        if info.register_for is not None:
            dispatch_registrations.append(generateShaderDispatchStr(info, shader_name))

    rendered = cpp_template.format(
        spv_bin_arrays="\n".join(bin_arrays),
        register_shader_infos="\n".join(info_registrations),
        shader_info_registry="\n".join(dispatch_registrations),
    )

    with open(cpp_src_file_path, "w") as out_file:
        out_file.write(rendered)
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker##########
706*da0073e9SAndroid Build Coastguard Worker#  Main  #
707*da0073e9SAndroid Build Coastguard Worker##########
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Worker
def parse_arg_env(items: list[str] | None) -> dict[str, str]:
    """Turn ``--env`` arguments of the form ``KEY=VALUE`` into a dict.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=`` (``"A=x=y"`` gives ``{"A": "x=y"}``). Whitespace around
    both parts is stripped, and a repeated key keeps its last value.

    Args:
        items: The raw strings, or ``None``/empty when ``--env`` was not given.

    Returns:
        Mapping of keys to values. It is empty when *items* is falsy.

    Raises:
        ValueError: if an entry contains no ``=``.
    """
    d: dict[str, str] = {}
    if not items:
        return d
    for item in items:
        # partition splits on the first "=" only; split("=") would drop
        # everything after a second "=" in the value.
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"--env entry {item!r} is not of the form KEY=VALUE")
        d[key.strip()] = value.strip()
    return d
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker
def main(argv: list[str]) -> int:
    """Generate GLSL/SPIR-V for every shader and the C++ file that registers them.

    Args:
        argv: Full argument vector in ``sys.argv`` form. ``argv[0]`` (the
            program name) is skipped and ``argv[1:]`` is parsed. Previously
            this parameter was ignored and ``sys.argv`` was re-read.

    Returns:
        0 on success. Errors propagate as exceptions; for example, a failing
        glslc run raises ``subprocess.CalledProcessError`` out of ``generateSPV``.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-i",
        "--glsl-paths",
        nargs="+",
        # With nargs="+", paths are separate arguments; quoting them together
        # would produce a single path containing spaces.
        help="List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths path1 path2 path3",
        default=["."],
    )
    parser.add_argument(
        "-c", "--glslc-path", required=True, help="Path to the glslc compiler."
    )
    parser.add_argument(
        "-t",
        "--tmp-dir-path",
        required=True,
        help="Directory for the generated .glsl and .spv files (e.g. /tmp).",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        required=True,
        help="Directory to write the generated C++ source into.",
    )
    parser.add_argument(
        "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
    )
    options = parser.parse_args(argv[1:])

    DEFAULT_ENV.update(TYPES_ENV)
    DEFAULT_ENV.update(FUNCS_ENV)
    env = DEFAULT_ENV

    # Command-line KEY=VALUE pairs override the built-in environment.
    env.update(parse_arg_env(options.env))

    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path)
    output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

    genCppFiles(
        output_spv_files,
        os.path.join(options.output_path, CPP_H_NAME),
        os.path.join(options.output_path, CPP_SRC_NAME),
    )

    return 0
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker
def invoke_main() -> None:
    """Run ``main`` on the process arguments and exit with its return code."""
    exit_status = main(sys.argv)
    raise SystemExit(exit_status)
765*da0073e9SAndroid Build Coastguard Worker
766*da0073e9SAndroid Build Coastguard Worker
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    invoke_main()  # pragma: no cover
769