#!/usr/bin/env python3

from __future__ import annotations

import argparse
import array
import codecs
import copy
import glob
import io
import os
import re
import sys
from itertools import product

# Make the parent directory importable from this script.
# NOTE(review): no import that needs this path is visible in this chunk —
# confirm it is still required.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
import textwrap
from dataclasses import dataclass
from typing import Any

import yaml
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode

# Prefer PyYAML's C-accelerated loader; fall back to the pure-Python Loader
# when the C extension is unavailable.
try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader  # type: ignore[assignment, misc]

# Output file names for the generated C++ header/source.
# NOTE(review): their use is outside this chunk.
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"
# Default template variables (precision qualifier and per-dtype image formats).
# NOTE(review): presumably passed as SPVGenerator's `env`; the call site is not
# in this chunk.
DEFAULT_ENV: dict[str, Any] = {
    "PRECISION": "highp",
    "FLOAT_IMAGE_FORMAT": "rgba16f",
    "INT_IMAGE_FORMAT": "rgba32i",
    "UINT_IMAGE_FORMAT": "rgba32ui",
}

# Per-dtype lookup tables usable from templates. IMAGE_T and SAMPLER_T are keyed
# first by dimensionality (3 -> *3D, 2 -> *2D), then by dtype name.
TYPES_ENV: dict[str, Any] = {
    "IMAGE_FORMAT": {
        "float": "rgba32f",
        "half": "rgba16f",
        "int": "rgba32i",
        "uint": "rgba32ui",
        "int8": "rgba8i",
        "uint8": "rgba8ui",
    },
    "IMAGE_T": {
        3: {
            "float": "image3D",
            "half": "image3D",
            "int": "iimage3D",
            "uint": "uimage3D",
        },
        2: {
            "float": "image2D",
            "half": "image2D",
            "int": "iimage2D",
            "uint": "uimage2D",
        },
    },
    "SAMPLER_T": {
        3: {
            "float": "sampler3D",
            "half": "sampler3D",
            "int": "isampler3D",
            "uint": "usampler3D",
        },
        2: {
            "float": "sampler2D",
            "half": "sampler2D",
            "int": "isampler2D",
            "uint": "usampler2D",
        },
    },
    # NOTE(review): "int8" maps to a float "vec4" while "uint8" maps to
    # "uvec4" — confirm this asymmetry is intended.
    "VEC4_T": {
        "float": "vec4",
        "half": "vec4",
        "int": "ivec4",
        "uint": "uvec4",
        "int8": "vec4",
        "uint8": "uvec4",
    },
    # NOTE(review): "uint8" is not a built-in GLSL scalar type ("int8" maps to
    # "int"); verify templates define it, or whether "uint" was intended.
    "T": {
        "float": "float",
        "half": "float",
        "int": "int",
        "uint": "uint",
        "int8": "int",
        "uint8": "uint8",
    },
}

# Callables usable from templates. GET_POS[ndim](pos) returns the position
# expression: unchanged for 3D, swizzled to ".xy" for 2D.
FUNCS_ENV: dict[str, Any] = {
    "GET_POS": {
        3: lambda pos: pos,
        2: lambda pos: f"{pos}.xy",
    }
}


def extract_filename(path: str, keep_ext: bool = True) -> str:
    """Return the base name of ``path``.

    With ``keep_ext=False`` everything from the first "." onward is dropped,
    e.g. ``"dir/a.b.glsl"`` -> ``"a"``.
    """
    if keep_ext:
        return os.path.basename(path)
    else:
        return os.path.basename(path).split(".")[0]


############################
#  SPIR-V Code Generation  #
############################


# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
    """YAML loader that rejects mappings containing duplicate keys.

    Plain PyYAML keeps the last value for a repeated key; this loader raises
    a ``ConstructorError`` pointing at the duplicate instead.
    NOTE(review): unlike PyYAML's SafeConstructor.construct_mapping this does
    not call ``flatten_mapping``, so YAML merge keys (``<<``) are presumably
    unsupported — verify if templates ever need them.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            # Keys must be hashable to be stored in a dict.
            try:
                hash(key)
            except TypeError as e:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found unacceptable key ",
                    key_node.start_mark,
                ) from e
            # check for duplicate keys
            if key in mapping:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found duplicate key",
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
            mapping[key] = value
        return mapping


# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def extract_leading_whitespace(line: str) -> str:
    """Return the run of whitespace at the start of ``line`` (possibly "")."""
    match = re.match(r"\s*", line)
    return match.group(0) if match else ""


# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def escape(line: str) -> str:
    """Translate one template text line into a Python string expression.

    Literal text becomes double-quoted string literals and each ``${expr}``
    becomes ``str(expr)``; the parts are joined with ``+``. The result is used
    as the argument of a generated ``print(...)`` call.

    Backslashes are escaped before double quotes, so literal text (including a
    trailing backslash such as a GLSL macro line continuation) is reproduced
    verbatim. An empty line yields ``'""'`` so the generated call stays valid.

    Raises:
        ValueError: if a ``${`` has no closing ``}`` on the same line.
    """

    def quote(text: str) -> str:
        return '"' + text.replace("\\", "\\\\").replace('"', '\\"') + '"'

    output_parts = []
    while "${" in line:
        start_pos = line.index("${")
        end_pos = line.index("}", start_pos + 2)
        if start_pos != 0:
            output_parts.append(quote(line[:start_pos]))
        output_parts.append("str(" + line[start_pos + 2 : end_pos] + ")")
        line = line[end_pos + 1 :]
    if line:
        output_parts.append(quote(line))
    return " + ".join(output_parts) or '""'


# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess(
    input_text: str, variables: dict[str, Any], input_path: str = "codegen"
) -> str:
    """Expand a template into its rendered text.

    Template syntax:
      * a line whose first non-blank character is ``$`` (but which does not
        start with ``${``) is Python code with every ``$`` removed; if it ends
        with ``:`` the following, more-indented lines form its block;
      * any other line is printed as text, with ``${expr}`` replaced by
        ``str(expr)``;
      * lines containing ``LINT`` are dropped; empty lines are preserved.

    The template is translated into a Python program that prints into an
    in-memory stream, compiled with ``input_path`` as its filename, and
    executed with a copy of ``variables`` as its globals.

    Args:
        input_text: the template source.
        variables: names visible to template expressions and code lines.
        input_path: filename reported by ``compile`` in error messages.

    Returns:
        The rendered text.
    """
    input_lines = input_text.splitlines()
    python_lines = []

    # Pending empty input lines; flushed as bare print() calls at the
    # indentation of the next emitted line (or the last one, at the end).
    blank_lines = 0

    last_indent = ""
    # Initialized so a template made only of empty lines does not hit an
    # unbound local when the trailing blank lines are flushed.
    python_indent = ""

    # List of tuples (template_indent, python_indent)
    indent_stack = [("", "")]

    # Indicates whether this is the first line inside Python
    # code block (i.e. for, while, if, elif, else)
    python_block_start = True
    for input_line in input_lines:
        if input_line == "":
            blank_lines += 1
            continue
        # Skip lint markers.
        if "LINT" in input_line:
            continue

        input_indent = extract_leading_whitespace(input_line)
        if python_block_start:
            # The template indentation added relative to the previous line is
            # appended to the enclosing Python indentation for this block.
            assert input_indent.startswith(last_indent)
            extra_python_indent = input_indent[len(last_indent) :]
            python_indent = indent_stack[-1][1] + extra_python_indent
            indent_stack.append((input_indent, python_indent))
            assert input_indent.startswith(indent_stack[-1][0])
        else:
            # Dedent: leave every block this line is no longer nested in.
            while not input_indent.startswith(indent_stack[-1][0]):
                del indent_stack[-1]
        python_block_start = False

        python_indent = indent_stack[-1][1]
        stripped_input_line = input_line.strip()
        if stripped_input_line.startswith("$") and not stripped_input_line.startswith(
            "${"
        ):
            # Python code line; a trailing ":" opens a nested block.
            if stripped_input_line.endswith(":"):
                python_block_start = True
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(python_indent + stripped_input_line.replace("$", ""))
        else:
            # Text line: strip the Python indentation and print the remainder.
            assert input_line.startswith(python_indent)
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(
                python_indent
                + f"print({escape(input_line[len(python_indent) :])}, file=OUT_STREAM)"
            )
        last_indent = input_indent

    while blank_lines != 0:
        python_lines.append(python_indent + "print(file=OUT_STREAM)")
        blank_lines -= 1

    exec_globals = dict(variables)
    output_stream = io.StringIO()
    exec_globals["OUT_STREAM"] = output_stream

    python_bytecode = compile("\n".join(python_lines), input_path, "exec")
    exec(python_bytecode, exec_globals)

    return output_stream.getvalue()


class SPVGenerator:
    """Render GLSL shader templates and optionally compile them to SPIR-V.

    ``*.glsl*`` sources and ``*.yaml`` template descriptions are collected
    recursively from the source directories. Each YAML-described variant, plus
    each source without a YAML description, becomes one entry of
    ``output_shader_map`` (output name -> (source path, template variables)).

    Args:
        src_dir_paths: one directory or a list of directories to scan; also
            passed to ``glslc`` as include paths.
        env: base template variables deep-copied into every shader's params.
        glslc_path: path to ``glslc``, or ``None`` to only write the rendered
            GLSL.
    """

    def __init__(
        self,
        src_dir_paths: str | list[str],
        env: dict[Any, Any],
        glslc_path: str | None,
    ) -> None:
        if isinstance(src_dir_paths, str):
            self.src_dir_paths = [src_dir_paths]
        else:
            self.src_dir_paths = src_dir_paths

        self.env = env
        self.glslc_path = glslc_path

        # Shader name (file name up to the first ".") -> source path.
        self.glsl_src_files: dict[str, str] = {}
        self.template_yaml_files: list[str] = []

        self.addSrcAndYamlFiles(self.src_dir_paths)
        # Template name -> list of resolved parameter dicts, one per variant.
        self.shader_template_params: dict[Any, Any] = {}
        for yaml_file in self.template_yaml_files:
            self.parseTemplateYaml(yaml_file)

        self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
        self.constructOutputMap()

    def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
        """Recursively collect ``*.glsl*`` sources and ``*.yaml`` templates.

        Glob results are sorted so the discovery order (and therefore output
        order, and which file wins when two sources share a name) does not
        depend on directory-listing order.
        """
        for src_path in src_dir_paths:
            # Collect glsl source files
            glsl_files = sorted(
                glob.glob(os.path.join(src_path, "**", "*.glsl*"), recursive=True)
            )
            for file in glsl_files:
                self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
            # Collect template yaml files
            yaml_files = sorted(
                glob.glob(os.path.join(src_path, "**", "*.yaml"), recursive=True)
            )
            self.template_yaml_files.extend(yaml_files)

    def generateVariantCombinations(
        self,
        iterated_params: dict[str, Any],
        exclude_params: set[str] | None = None,
    ) -> list[Any]:
        """Return the Cartesian product of the iterated parameter values.

        Each result element is a tuple holding one ``(param_name, suffix,
        value)`` triple per parameter not in ``exclude_params``; ``suffix`` is
        the entry's ``SUFFIX`` or, if absent, its ``VALUE``.
        """
        if exclude_params is None:
            exclude_params = set()
        all_iterated_params = [
            [
                (param_name, value.get("SUFFIX", value["VALUE"]), value["VALUE"])
                for value in value_list
            ]
            for param_name, value_list in iterated_params.items()
            if param_name not in exclude_params
        ]
        return list(product(*all_iterated_params))

    @staticmethod
    def _variantParams(
        default_params: dict[str, Any],
        variant: dict[str, Any],
        combination: tuple[Any, ...] | None = None,
    ) -> dict[str, Any]:
        """Resolve one variant's parameters on top of a copy of the defaults.

        With ``combination=None`` all of the variant's keys override the
        defaults. Otherwise ``generate_variant_forall`` is skipped, each
        ``(name, suffix, value)`` triple of the combination is applied, and
        every non-empty suffix is appended to the variant's ``NAME`` as
        ``_<suffix>``. Suffixes are converted with ``str`` so that non-string
        VALUEs used as suffixes (e.g. integers) work.
        """
        params = copy.deepcopy(default_params)
        if combination is None:
            params.update(variant)
            return params

        for key, value in variant.items():
            if key != "generate_variant_forall":
                params[key] = value

        variant_name = variant["NAME"]
        for param_name, suffix, value in combination:
            params[param_name] = value
            suffix_str = str(suffix)
            if suffix_str:
                variant_name = f"{variant_name}_{suffix_str}"
        params["NAME"] = variant_name
        return params

    def parseTemplateYaml(self, yaml_file: str) -> None:
        """Expand every template described in ``yaml_file`` into variant params.

        Keys read per template: ``parameter_names_with_default_values``
        (mapping), optional ``generate_variant_forall`` (param -> list of
        ``{VALUE, SUFFIX?}``), and ``shader_variants`` (list of mappings with
        ``NAME``, optional parameter overrides, and an optional per-variant
        ``generate_variant_forall``).

        Raises:
            KeyError: if a template name is defined more than once.
            AssertionError: if a variant uses a key that is not a declared
                parameter, ``NAME`` or ``generate_variant_forall``.
        """
        # UniqueKeyLoader extends PyYAML's full (non-safe) Loader; template
        # YAML files are expected to be trusted repository inputs.
        with open(yaml_file, encoding="utf-8") as f:
            contents = yaml.load(f, Loader=UniqueKeyLoader)
        for template_name, params_dict in contents.items():
            if template_name in self.shader_template_params:
                raise KeyError(f"{template_name} params file is defined twice")

            default_params = params_dict["parameter_names_with_default_values"]
            params_names = set(default_params.keys()).union({"NAME"})

            variant_list: list[dict[str, Any]] = []
            self.shader_template_params[template_name] = variant_list

            default_iterated_params = params_dict.get("generate_variant_forall", None)

            for variant in params_dict["shader_variants"]:
                variant_params_names = set(variant.keys())
                invalid_keys = (
                    variant_params_names - params_names - {"generate_variant_forall"}
                )
                # Explicit raise (not assert) so the check survives `python -O`.
                if invalid_keys:
                    raise AssertionError(
                        f"{template_name} in {yaml_file}: variant uses undeclared "
                        f"parameters {invalid_keys}"
                    )

                iterated_params = variant.get(
                    "generate_variant_forall", default_iterated_params
                )

                if iterated_params is None:
                    variant_list.append(self._variantParams(default_params, variant))
                    continue

                # Parameters the variant sets explicitly are not iterated over.
                for combination in self.generateVariantCombinations(
                    iterated_params, variant_params_names
                ):
                    variant_list.append(
                        self._variantParams(default_params, variant, combination)
                    )

    def create_shader_params(
        self, variant_params: dict[str, Any] | None = None
    ) -> dict[str, str]:
        """Return a deep copy of ``env`` overlaid with ``variant_params``.

        ``FORMAT`` is set from ``DTYPE`` (default ``"float"``): "int"/"uint"
        use the env's INT/UINT image formats, the sized dtypes use fixed
        formats, and anything else uses ``FLOAT_IMAGE_FORMAT``.
        """
        if variant_params is None:
            variant_params = {}
        shader_params = copy.deepcopy(self.env)
        for key, value in variant_params.items():
            shader_params[key] = value

        shader_dtype = shader_params.get("DTYPE", "float")

        if shader_dtype == "int":
            shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
        elif shader_dtype == "uint":
            shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
        elif shader_dtype == "int32":
            shader_params["FORMAT"] = "rgba32i"
        elif shader_dtype == "uint32":
            shader_params["FORMAT"] = "rgba32ui"
        elif shader_dtype == "int8":
            shader_params["FORMAT"] = "rgba8i"
        elif shader_dtype == "uint8":
            shader_params["FORMAT"] = "rgba8ui"
        elif shader_dtype == "float32":
            shader_params["FORMAT"] = "rgba32f"
        # Assume float by default
        else:
            shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]

        return shader_params

    def constructOutputMap(self) -> None:
        """Fill ``output_shader_map`` with every variant and plain source.

        Templated shaders contribute one entry per variant (keyed by the
        variant's ``NAME``); sources without a template contribute one entry
        under their own name with default params.

        Raises:
            KeyError: if a template has variants but no matching source file.
        """
        for shader_name, params in self.shader_template_params.items():
            for variant in params:
                source_glsl = self.glsl_src_files[shader_name]

                self.output_shader_map[variant["NAME"]] = (
                    source_glsl,
                    self.create_shader_params(variant),
                )

        for shader_name, source_glsl in self.glsl_src_files.items():
            if shader_name not in self.shader_template_params:
                self.output_shader_map[shader_name] = (
                    source_glsl,
                    self.create_shader_params(),
                )

    def generateSPV(self, output_dir: str) -> dict[str, str]:
        """Render every shader into ``output_dir`` and optionally compile it.

        Each entry is expanded with ``preprocess`` and written to
        ``<output_dir>/<name>.glsl``. When ``glslc_path`` is set, that file is
        compiled to ``<output_dir>/<name>.spv`` with the source directories
        as include paths.

        Returns:
            Mapping of generated ``.spv`` path -> ``.glsl`` path (empty when no
            compiler is configured).

        Raises:
            subprocess.CalledProcessError: if ``glslc`` fails.
        """
        output_file_map = {}
        for shader_name, (source_glsl, shader_params) in self.output_shader_map.items():
            # newline="" keeps line endings untranslated, matching the previous
            # codecs.open behavior (codecs.open is deprecated).
            with open(source_glsl, encoding="utf-8", newline="") as input_file:
                input_text = input_file.read()
            # Pass the source path so template errors name the offending file.
            output_text = preprocess(input_text, shader_params, input_path=source_glsl)

            glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
            with open(glsl_out_path, "w", encoding="utf-8", newline="") as output_file:
                output_file.write(output_text)

            # If no GLSL compiler is specified, then only write out the generated GLSL shaders.
            # This is mainly for testing purposes.
            if self.glslc_path is not None:
                spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

                cmd = [
                    self.glslc_path,
                    "-fshader-stage=compute",
                    glsl_out_path,
                    "-o",
                    spv_out_path,
                    "--target-env=vulkan1.0",
                    "-Werror",
                ] + [
                    arg
                    for src_dir_path in self.src_dir_paths
                    for arg in ["-I", src_dir_path]
                ]

                print("glslc cmd:", cmd)
                subprocess.check_call(cmd)

                output_file_map[spv_out_path] = glsl_out_path

        return output_file_map


##############################################
#  Shader Info and Shader Registry Handling  #
##############################################


@dataclass
class ShaderInfo:
    """Per-shader metadata scraped from a GLSL source file by getShaderInfo()."""

    # Values from the ` * TILE_SIZE = (x, y, z)` header line; empty when absent.
    tile_size: list[int]
    # One descriptor-type name (see typeIdMapping) per `layout(set ...` line,
    # in the order those lines appear in the source.
    layouts: list[str]
    # Keys into storageTypeToEnum; "" maps to api::StorageType::UNKNOWN.
    weight_storage_type: str = ""
    bias_storage_type: str = ""
    # (op name, registry keys) from a ` * REGISTER_FOR = (...)` line, if any.
    register_for: tuple[str, list[str]] | None = None


def getName(filePath: str) -> str:
    """Turn a file path into an identifier-like name.

    Takes the basename and replaces "/" and "." with "_",
    e.g. "shaders/conv2d.spv" -> "conv2d_spv".
    """
    return os.path.basename(filePath).replace("/", "_").replace(".", "_")


def isDescriptorLine(lineStr: str) -> bool:
    """Return True if the line starts a descriptor declaration (``layout(set``)."""
    descriptorLineId = r"^layout\(set"
    return re.search(descriptorLineId, lineStr) is not None


def isTileSizeLine(lineStr: str) -> bool:
    """Return True if the line is a `` * TILE_SIZE = (`` header comment line."""
    tile_size_id = r"^ \* TILE_SIZE = \("
    return re.search(tile_size_id, lineStr) is not None


def findTileSizes(lineStr: str) -> list[int]:
    """Parse a `` * TILE_SIZE = (x, y, z)`` header line into ``[x, y, z]``.

    Raises:
        AssertionError: if the line does not hold three integers in that form.
    """
    tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
    matches = re.search(tile_size_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in findTileSizes")
    return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]


# Storage-type values accepted after WEIGHT_STORAGE / BIAS_STORAGE: the
# "<NAME>_<n>D" texture forms plus BUFFER. Without the BUFFER alternative,
# the "BUFFER" entry of storageTypeToEnum could never be produced.
_STORAGE_TYPE_VALUE = r"([a-zA-Z]+_\dD|BUFFER)"


def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line is a `` * WEIGHT_STORAGE = `` header comment line."""
    weight_storage_id = r"^ \* WEIGHT_STORAGE = "
    return re.search(weight_storage_id, lineStr) is not None


def getWeightStorageType(lineStr: str) -> str:
    """Extract the weight storage type (e.g. "TEXTURE_2D", "BUFFER").

    Raises:
        AssertionError: if no recognised storage type follows the ``=``.
    """
    weight_storage_id = r"^ \* WEIGHT_STORAGE = " + _STORAGE_TYPE_VALUE
    matches = re.search(weight_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return matches.group(1)


def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line is a `` * BIAS_STORAGE = `` header comment line."""
    bias_storage_id = r"^ \* BIAS_STORAGE = "
    return re.search(bias_storage_id, lineStr) is not None


def getBiasStorageType(lineStr: str) -> str:
    """Extract the bias storage type (e.g. "TEXTURE_2D", "BUFFER").

    Raises:
        AssertionError: if no recognised storage type follows the ``=``.
    """
    bias_storage_id = r"^ \* BIAS_STORAGE = " + _STORAGE_TYPE_VALUE
    matches = re.search(bias_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return matches.group(1)


def isRegisterForLine(lineStr: str) -> bool:
    """Return True for `` * REGISTER_FOR = ('<op>', ['<key>', ...])`` lines.

    Requires an op name and at least one quoted registry key.
    """
    register_for_id = (
        r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
    )
    return re.search(register_for_id, lineStr) is not None


def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
    """Split a REGISTER_FOR line into ``(op_name, [registry_key, ...])``.

    The first single-quoted identifier is the op name; the rest are the keys.

    Raises:
        AssertionError: if the line contains no single-quoted identifiers.
    """
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    # re.findall always returns a list (possibly empty), never None.
    matches = re.findall(register_for_pattern, lineStr)
    if not matches:
        raise AssertionError("no quoted identifiers found in findRegisterFor")
    return (matches[0], matches[1:])


# Regex -> Vulkan descriptor type. determineDescriptorType() returns the first
# pattern that matches, in this insertion order, so image/sampler types are
# checked before the generic `buffer` / `uniform` keywords.
typeIdMapping = {
    r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
    r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
    r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
    r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER",
}

# Storage-type header value -> C++ enum spelling ("" means "not specified").
storageTypeToEnum = {
    "TEXTURE_2D": "api::StorageType::TEXTURE_2D",
    "TEXTURE_3D": "api::StorageType::TEXTURE_3D",
    "BUFFER": "api::StorageType::BUFFER",
    "": "api::StorageType::UNKNOWN",
}


def determineDescriptorType(lineStr: str) -> str:
    """Map a ``layout(set ...`` declaration line to its Vulkan descriptor type.

    Raises:
        AssertionError: if no pattern in typeIdMapping matches the line.
    """
    for identifier, typeNum in typeIdMapping.items():
        if re.search(identifier, lineStr):
            return typeNum
    raise AssertionError(
        "No matching descriptor type for " + lineStr + " in determineDescriptorType"
    )


def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Scan a GLSL source file line by line and collect its ShaderInfo.

    Descriptor lines accumulate into ``layouts``; for every other header kind
    a later line overwrites an earlier one.
    """
    shader_info = ShaderInfo([], [], "")
    with open(srcFilePath) as srcFile:
        for line in srcFile:
            if isDescriptorLine(line):
                shader_info.layouts.append(determineDescriptorType(line))
            if isTileSizeLine(line):
                shader_info.tile_size = findTileSizes(line)
            if isWeightStorageTypeLine(line):
                shader_info.weight_storage_type = getWeightStorageType(line)
            if isBiasStorageTypeLine(line):
                shader_info.bias_storage_type = getBiasStorageType(line)
            if isRegisterForLine(line):
                shader_info.register_for = findRegisterFor(line)

    return shader_info


##########################
# C++ File Generation #
#########################

# Skeleton of the generated C++ source, filled via str.format(); doubled
# braces are literal C++ braces. The registration callback is passed by
# address (`&register_fn`).
cpp_template = """
#include <ATen/native/vulkan/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>

using namespace at::native::vulkan;

namespace at {{
namespace native {{
namespace vulkan {{

namespace {{

{spv_bin_arrays}

}}

static void register_fn() {{

{register_shader_infos}

{shader_info_registry}

}}

static const api::ShaderRegisterInit register_shaders(&register_fn);

}}
}}
}}

"""


def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
    """Embed a compiled SPIR-V file as a C++ ``const uint32_t <name>_bin[]`` array.

    Returns:
        ``(size_in_bytes, cpp_array_definition)``.

    Raises:
        ValueError: if the file length is not a multiple of the word size.
        RuntimeError: if ``array('I')`` is not 32 bits wide on this platform,
            which would make the emitted values not uint32 words.
    """
    with open(spvPath, "rb") as fr:
        raw = fr.read()
    next_bin = array.array("I", raw)
    if next_bin.itemsize != 4:
        raise RuntimeError("array typecode 'I' is not 4 bytes on this platform")
    # The byte count is the file length itself, not an assumed 4 * word count.
    sizeBytes = len(raw)
    spv_bin_str = "const uint32_t {}_bin[] = {{\n{}\n}};".format(
        name,
        textwrap.indent(",\n".join(str(x) for x in next_bin), " "),
    )

    return sizeBytes, spv_bin_str
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
    """Render the C++ statement that registers one shader's ShaderInfo.

    Args:
        shader_info: Metadata parsed from the shader's GLSL source.
        name: Shader name; also the prefix of its ``<name>_bin`` array.
        sizeBytes: Size in bytes of the embedded SPIR-V array.

    Returns:
        An indented ``api::shader_registry().register_shader(...)`` statement.

    Raises:
        KeyError: if a storage type is missing from ``storageTypeToEnum``.
    """
    # Tile size: a brace initializer when known, otherwise an empty vector.
    if shader_info.tile_size:
        tile_size_expr = "{" + ", ".join(map(str, shader_info.tile_size)) + "}"
    else:
        tile_size_expr = "std::vector<uint32_t>()"

    layouts_expr = "{" + ",\n ".join(shader_info.layouts) + "}"

    weight_enum = storageTypeToEnum[shader_info.weight_storage_type]
    bias_enum = storageTypeToEnum[shader_info.bias_storage_type]

    # Constructor arguments, in the order api::ShaderInfo expects them.
    ctor_args = ",\n".join(
        (
            f'"{name}"',
            f"{name}_bin",
            str(sizeBytes),
            layouts_expr,
            tile_size_expr,
            weight_enum,
            bias_enum,
        )
    )
    indented_args = textwrap.indent(ctor_args, " ")

    statement = (
        "api::shader_registry().register_shader(\n"
        " api::ShaderInfo(\n"
        + indented_args
        + "));\n"
    )
    return textwrap.indent(statement, " ")
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
    """Render one ``register_op_dispatch`` C++ call per registry key.

    Args:
        shader_info: Parsed shader metadata; only ``register_for`` is read.
        name: Shader name to register as the dispatch target.

    Returns:
        The calls joined by newlines, or "" when the shader declares no
        REGISTER_FOR entry (or an empty key list).
    """
    if shader_info.register_for is None:
        return ""

    (op_name, registry_keys) = shader_info.register_for
    # Emit a call for EVERY key; assigning inside the loop kept only the last.
    dispatch_lines = [
        textwrap.indent(
            f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");',
            " ",
        )
        for registry_key in registry_keys
    ]
    return "\n".join(dispatch_lines)


def genCppFiles(
    spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
    """Write the C++ source that embeds each SPIR-V binary and registers it.

    Args:
        spv_files: Map from compiled ``.spv`` path to its GLSL source path.
        cpp_header_path: Accepted for interface compatibility; this function
            does not write a header.
        cpp_src_file_path: Destination of the generated C++ source.
    """
    spv_bin_strs = []
    register_shader_info_strs = []
    shader_registry_strs = []

    for spvPath, srcPath in spv_files.items():
        base_name = getName(spvPath)
        # Drop only the trailing "_spv" (from the ".spv" extension); a plain
        # replace() would also strip "_spv" from the middle of a shader name.
        name = base_name[: -len("_spv")] if base_name.endswith("_spv") else base_name

        sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
        spv_bin_strs.append(spv_bin_str)

        shader_info = getShaderInfo(srcPath)

        register_shader_info_strs.append(
            generateShaderInfoStr(shader_info, name, sizeBytes)
        )

        if shader_info.register_for is not None:
            shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))

    spv_bin_arrays = "\n".join(spv_bin_strs)
    register_shader_infos = "\n".join(register_shader_info_strs)
    shader_info_registry = "\n".join(shader_registry_strs)

    cpp = cpp_template.format(
        spv_bin_arrays=spv_bin_arrays,
        register_shader_infos=register_shader_infos,
        shader_info_registry=shader_info_registry,
    )

    with open(cpp_src_file_path, "w") as fw:
        fw.write(cpp)


##########
# Main #
##########


def parse_arg_env(items: list[str] | None) -> dict[str, str]:
    """Parse ``KEY=VALUE`` strings from ``--env`` into a dict.

    Splits on the first "=", so values may themselves contain "=". Keys and
    values are stripped of surrounding whitespace. ``None`` (flag not given)
    and an empty list both yield ``{}``.

    Raises:
        ValueError: if an item contains no "=".
    """
    d: dict[str, str] = {}
    if not items:
        return d
    for item in items:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"--env expects KEY=VALUE pairs, got {item!r}")
        d[key.strip()] = value.strip()
    return d


def main(argv: list[str]) -> int:
    """Parse the command line, compile the GLSL shaders, and emit the C++ registry.

    Args:
        argv: Full command line including the program name (as ``sys.argv``).

    Returns:
        Process exit status (0 on success).
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-i",
        "--glsl-paths",
        nargs="+",
        help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
        default=["."],
    )
    parser.add_argument("-c", "--glslc-path", required=True, help="")
    parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
    parser.add_argument("-o", "--output-path", required=True, help="")
    parser.add_argument(
        "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
    )
    # Parse the argv we were handed (minus the program name) rather than
    # silently ignoring it and re-reading sys.argv.
    options = parser.parse_args(argv[1:])

    # Build a fresh env instead of mutating the module-level DEFAULT_ENV, so
    # --env overrides do not leak into the globals.
    env = {**DEFAULT_ENV, **TYPES_ENV, **FUNCS_ENV}
    env.update(parse_arg_env(options.env))

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path)
    output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

    genCppFiles(
        output_spv_files,
        f"{options.output_path}/{CPP_H_NAME}",
        f"{options.output_path}/{CPP_SRC_NAME}",
    )

    return 0


def invoke_main() -> None:
    """Console entry point: run main() on sys.argv and exit with its status."""
    sys.exit(main(sys.argv))


if __name__ == "__main__":
    invoke_main()  # pragma: no cover