#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import array
import codecs
import copy
import glob
import io
import os
import re
import sys
from itertools import product
from multiprocessing.pool import ThreadPool

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
import textwrap
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import yaml
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader  # type: ignore[assignment, misc]

CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"

# Basic configuration settings for shaders
DEFAULT_ENV: Dict[str, Any] = {
    "PRECISION": "highp",
    # B is shorthand for "binding". This is used to automatically increment the
    # layout binding index when declaring layout bindings. Note that a container
    # type is used because integers are immutable in Python.
    "B": [0],
    # C is shorthand for "constant_id". This is used to automatically increment the
    # constant_id index for specialization constants.
    # Note that it starts at 3, as 0-2 are reserved for local workgroup size ids.
    "C": [3],
}

# Establishes relationships between different tensor types and different GLSL types
TYPE_MAPPINGS: Dict[str, Any] = {
    "IMAGE_T": {
        3: {
            "float": "image3D",
            "half": "image3D",
            "int": "iimage3D",
            "uint": "uimage3D",
            "int8": "iimage3D",
            "uint8": "uimage3D",
        },
        2: {
            "float": "image2D",
            "half": "image2D",
            "int": "iimage2D",
            "uint": "uimage2D",
            "int8": "iimage2D",
            "uint8": "uimage2D",
        },
    },
    "SAMPLER_T": {
        3: {
            "float": "sampler3D",
            "half": "sampler3D",
            "int": "isampler3D",
            "uint": "usampler3D",
            "int8": "isampler3D",
            "uint8": "usampler3D",
        },
        2: {
            "float": "sampler2D",
            "half": "sampler2D",
            "int": "isampler2D",
            "uint": "usampler2D",
            "int8": "isampler2D",
            "uint8": "usampler2D",
        },
    },
    "IMAGE_FORMAT": {
        "float": "rgba32f",
        "half": "rgba16f",
        "int": "rgba32i",
        "uint": "rgba32ui",
        "int8": "rgba8i",
        "uint8": "rgba8ui",
    },
}


def define_variable(name: str) -> str:
    """Return a ``#define`` directive for the module-level global ``name``.

    Only this module's globals are consulted. (Checking ``locals()`` here is
    pointless: inside this function it only ever contains ``name`` itself.)

    Raises:
        RuntimeError: if ``name`` is not a global of this module.
    """
    module_globals = globals()
    if name in module_globals:
        return f"#define {name} {module_globals[name]}"
    raise RuntimeError(f"{name} is not defined")


def buffer_scalar_type(dtype: str) -> str:
    """Return the GLSL scalar type used for ``dtype`` elements in a buffer.

    ``half`` maps to ``float16_t``; 8-bit names (``int8``/``uint8``) get a
    ``_t`` suffix; anything else is returned unchanged.
    """
    if dtype == "half":
        return "float16_t"
    elif dtype[-1] == "8":
        return dtype + "_t"

    return dtype


def buffer_gvec_type(dtype: str, n: int) -> str:
    """Return the GLSL ``n``-component vector type for ``dtype`` in a buffer.

    ``n == 1`` yields the scalar type from :func:`buffer_scalar_type`.

    Raises:
        AssertionError: for a dtype with no known vector type.
    """
    if n == 1:
        return buffer_scalar_type(dtype)

    if dtype == "float":
        return f"vec{n}"
    elif dtype == "half":
        return f"f16vec{n}"
    elif dtype == "int":
        return f"ivec{n}"
    elif dtype == "uint":
        # "uint" is supported by every TYPE_MAPPINGS table, so support it here too.
        return f"uvec{n}"
    elif dtype == "int8":
        return f"i8vec{n}"
    elif dtype == "uint8":
        return f"u8vec{n}"

    raise AssertionError(f"Invalid dtype: {dtype}")


def texel_type(dtype: str) -> str:
    """Return the 4-component GLSL vector type of a texel of ``dtype``.

    Derived from the image format suffix: ``f`` -> ``vec4``, ``ui`` -> ``uvec4``,
    ``i`` -> ``ivec4``.
    """
    image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
    if image_format[-1] == "f":
        return "vec4"
    # Slice the last two characters; indexing [-2] yields one character and
    # could never equal "ui", which sent unsigned formats to "ivec4".
    elif image_format[-2:] == "ui":
        return "uvec4"
    elif image_format[-1] == "i":
        return "ivec4"
    raise AssertionError(f"Invalid image format: {image_format}")


def gvec_type(dtype: str, n: int) -> str:
    """Return the ``n``-component texel vector type for ``dtype`` (e.g. ``ivec2``)."""
    gvec4_type = texel_type(dtype)
    return gvec4_type[:-1] + str(n)


def texel_component_type(dtype: str) -> str:
    """Return the scalar component type (``float``/``int``/``uint``) of a texel."""
    vec4_type = texel_type(dtype)
    if vec4_type[:3] == "vec":
        return "float"
    elif vec4_type[:4] == "ivec":
        return "int"
    elif vec4_type[:4] == "uvec":
        return "uint"
    raise AssertionError(f"Invalid vec4 type: {vec4_type}")


def texel_load_type(dtype: str, storage_type: str) -> str:
    """Return the 4-wide load type for ``dtype`` under ``storage_type``."""
    if storage_type.lower() == "buffer":
        return buffer_gvec_type(dtype, 4)
    else:
        return texel_type(dtype)


def texel_load_component_type(dtype: str, storage_type: str) -> str:
    """Return the scalar load type for ``dtype`` under ``storage_type``."""
    if storage_type.lower() == "buffer":
        return buffer_scalar_type(dtype)
    else:
        return texel_component_type(dtype)


def get_access_qualifier(access_type: Optional[str]) -> str:
    """Map an access string (``r``/``w``/``rw``/None) to a GLSL memory qualifier."""
    if access_type is None:
        return ""
    if access_type.lower() == "r":
        return "readonly"
    if access_type.lower() == "w":
        return "writeonly"
    if access_type.lower() == "rw":
        return ""

    raise AssertionError(f"Invalid access type: {access_type}")


def get_slot_val(slot: Union[int, List[int]]) -> int:
    """Return the binding index held by ``slot`` (a plain int or a 1-element counter list)."""
    if isinstance(slot, list):
        return slot[0]
    return slot


def layout_declare_buffer(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    precision: str = "PRECISION",
    is_scalar_array: bool = True,
) -> str:
    """Declare a storage buffer binding; advances ``slot`` in place when it is a list."""
    array_type = buffer_gvec_type(dtype, 4)
    if is_scalar_array:
        array_type = buffer_scalar_type(dtype)

    out_str = f"""
layout(set = 0, binding = {get_slot_val(slot)}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{
    {array_type} {var_name}[];
}};
"""

    if isinstance(slot, list):
        slot[0] = slot[0] + 1
    return out_str


def layout_declare_image(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    precision: str = "PRECISION",
    image_ndim: int = 3,
) -> str:
    """Declare a storage image binding; advances ``slot`` in place when it is a list."""
    image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
    image_type = TYPE_MAPPINGS["IMAGE_T"][image_ndim][dtype]

    ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};"

    if isinstance(slot, list):
        slot[0] = slot[0] + 1
    return ret_str


def layout_declare_sampler(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    precision: str = "PRECISION",
    access_qualifier: Optional[str] = None,
    image_ndim: int = 3,
) -> str:
    """Declare a sampler binding; advances ``slot`` in place when it is a list.

    ``access_type`` and ``access_qualifier`` are accepted for signature
    compatibility but are not used in the emitted declaration.
    """
    sampler_type = TYPE_MAPPINGS["SAMPLER_T"][image_ndim][dtype]

    ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} {sampler_type} {var_name};"

    if isinstance(slot, list):
        slot[0] = slot[0] + 1
    return ret_str


def layout_declare_tensor(
    slot: Union[int, List[int]],
    access_type: str,
    var_name: str,
    dtype: str,
    storage_type: str,
    is_scalar_array: bool = True,
    precision: str = "PRECISION",
) -> str:
    """Declare a tensor binding as a buffer, sampler (read-only) or image."""
    assert storage_type.lower() in ["buffer", "texture3d", "texture2d"]

    image_ndim = 3
    if storage_type.lower() == "texture2d":
        image_ndim = 2

    # Create buffer binding
    if storage_type.lower() == "buffer":
        return layout_declare_buffer(
            slot,
            access_type,
            var_name,
            dtype,
            precision,
            is_scalar_array=is_scalar_array,
        )

    # Create image/sampler binding
    if access_type.lower() == "r":
        return layout_declare_sampler(
            slot, access_type, var_name, dtype, precision, image_ndim=image_ndim
        )
    else:
        return layout_declare_image(
            slot, access_type, var_name, dtype, precision, image_ndim=image_ndim
        )


def layout_declare_ubo(
    slot: Union[int, List[int]], *args, precision: str = "PRECISION"
) -> str:
    """Declare a uniform buffer from alternating ``(type_name, var_name)`` args."""
    assert len(args) % 2 == 0

    var_list = list(zip(args[::2], args[1::2]))

    ubo_name = ""
    for _, var_name in var_list:
        ubo_name += var_name + "_"

    out_str = f"""
layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{
"""
    for type_name, var_name in var_list:
        out_str += f"  {type_name} {var_name};\n"
    out_str += "};"

    if isinstance(slot, list):
        slot[0] = slot[0] + 1
    return out_str


def layout_declare_spec_const(
    slot: Union[int, List[int]],
    type_name: str,
    var_name: str,
    initial_val: Optional[str] = None,
) -> str:
    """Declare a specialization constant; advances ``slot`` in place when it is a list."""
    assert type_name in ["int", "uint", "float", "bool"]

    out_str = f"layout(constant_id = {get_slot_val(slot)}) const {type_name} {var_name}"
    if initial_val is not None:
        out_str += f" = {initial_val}"
    out_str += ";"

    if isinstance(slot, list):
        slot[0] = slot[0] + 1
    return out_str


def define_active_storage_type(storage_type: str):
    """Return the ``#define USING_*`` macro for ``storage_type``."""
    if storage_type.lower() == "buffer":
        return "#define USING_BUFFER"
    elif storage_type.lower() == "texture3d":
        return "#define USING_TEXTURE3D"
    elif storage_type.lower() == "texture2d":
        return "#define USING_TEXTURE2D"
    else:
        raise AssertionError(f"Invalid storage type: {storage_type}")


def define_required_extensions(dtypes: Union[str, List[str]]):
    """Return the ``#extension`` lines needed for 8/16-bit dtypes in ``dtypes``."""
    out_str = "\n"
    dtype_list = dtypes if isinstance(dtypes, list) else [dtypes]

    for dtype in dtype_list:
        nbit = None
        glsl_type = None
        if dtype == "half":
            nbit = "16bit"
            glsl_type = "float16"
        elif dtype == "int16" or dtype == "uint16":
            nbit = "16bit"
            glsl_type = "int16"
        elif dtype == "int8" or dtype == "uint8":
            nbit = "8bit"
            glsl_type = "int8"

        if nbit is not None and glsl_type is not None:
            out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
            out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"

    return out_str


UTILITY_FNS: Dict[str, Any] = {
    "macro_define": define_variable,
    "get_pos": {
        3: lambda pos: pos,
        2: lambda pos: f"{pos}.xy",
    },
    "buffer_scalar_type": buffer_scalar_type,
    "buffer_gvec_type": buffer_gvec_type,
    "texel_type": texel_type,
    "gvec_type": gvec_type,
    "texel_component_type": texel_component_type,
    "texel_load_type": texel_load_type,
    "texel_load_component_type": texel_load_component_type,
    "layout_declare_buffer": layout_declare_buffer,
    "layout_declare_image": layout_declare_image,
    "layout_declare_sampler": layout_declare_sampler,
    "layout_declare_tensor": layout_declare_tensor,
    "layout_declare_ubo": layout_declare_ubo,
    "layout_declare_spec_const": layout_declare_spec_const,
    "define_active_storage_type": define_active_storage_type,
    "define_required_extensions": define_required_extensions,
}


def extract_filename(path: str, keep_ext: bool = True) -> Any:
    """Return the basename of ``path``, optionally cut at the first ``.``."""
    if keep_ext:
        return os.path.basename(path)
    else:
        return os.path.basename(path).split(".")[0]


############################
#  SPIR-V Code Generation  #
############################


# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
    """YAML loader that rejects mappings containing duplicate keys."""

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            try:
                hash(key)
            except TypeError as e:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found unacceptable key ",
                    key_node.start_mark,
                ) from e
            # check for duplicate keys
            if key in mapping:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found duplicate key",
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
            mapping[key] = value
        return mapping


# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def extract_leading_whitespace(line: str) -> str:
    """Return the leading whitespace prefix of ``line``."""
    match = re.match(r"\s*", line)
    return match.group(0) if match else ""


# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def escape(line: str) -> str:
    """Turn a template text line into a Python string expression.

    ``${expr}`` segments become ``str(expr)``; literal text becomes a quoted
    string literal. Backslashes are escaped before quotes so that e.g. a GLSL
    macro line continuation (trailing ``\\``) does not produce invalid Python.
    An empty line yields ``""`` so the caller's ``print(...)`` stays valid.
    """

    def quote(text: str) -> str:
        return '"' + text.replace("\\", "\\\\").replace('"', '\\"') + '"'

    output_parts = []
    while "${" in line:
        start_pos = line.index("${")
        end_pos = line.index("}", start_pos + 2)
        if start_pos != 0:
            output_parts.append(quote(line[:start_pos]))
        output_parts.append("str(" + line[start_pos + 2 : end_pos] + ")")
        line = line[end_pos + 1 :]
    if line:
        output_parts.append(quote(line))
    return " + ".join(output_parts) if output_parts else '""'


# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess(
    input_text: str, variables: Dict[str, Any], input_path: str = "codegen"
) -> str:
    """Expand a ``$``-templated source text into plain text.

    Lines starting with ``$`` (but not ``${``) are Python statements; every
    other line is emitted with ``${expr}`` substitutions. The template is
    compiled to Python and executed with ``variables`` as its globals.

    NOTE: this executes arbitrary Python from the template; only use it on
    trusted template files.
    """
    input_lines = input_text.splitlines()
    python_lines = []

    blank_lines = 0

    last_indent = ""
    # Initialized so trailing blank lines are handled even when the template
    # contains no non-blank lines.
    python_indent = ""

    # List of tuples (total_index, python_indent)
    indent_stack = [("", "")]

    # Indicates whether this is the first line inside Python
    # code block (i.e. for, while, if, elif, else)
    python_block_start = True
    for input_line in input_lines:
        if input_line == "":
            blank_lines += 1
            continue
        # Skip lint markers.
        if "LINT" in input_line:
            continue

        input_indent = extract_leading_whitespace(input_line)
        if python_block_start:
            assert input_indent.startswith(last_indent)
            extra_python_indent = input_indent[len(last_indent) :]
            python_indent = indent_stack[-1][1] + extra_python_indent
            indent_stack.append((input_indent, python_indent))
            assert input_indent.startswith(indent_stack[-1][0])
        else:
            while not input_indent.startswith(indent_stack[-1][0]):
                del indent_stack[-1]
        python_block_start = False

        python_indent = indent_stack[-1][1]
        stripped_input_line = input_line.strip()
        if stripped_input_line.startswith("$") and not stripped_input_line.startswith(
            "${"
        ):
            if stripped_input_line.endswith(":"):
                python_block_start = True
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(python_indent + stripped_input_line.replace("$", ""))
        else:
            assert input_line.startswith(python_indent)
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(
                python_indent
                + "print(%s, file=OUT_STREAM)"
                % escape(input_line[len(python_indent) :])
            )
        last_indent = input_indent

    while blank_lines != 0:
        python_lines.append(python_indent + "print(file=OUT_STREAM)")
        blank_lines -= 1

    exec_globals = dict(variables)
    output_stream = io.StringIO()
    exec_globals["OUT_STREAM"] = output_stream

    python_bytecode = compile("\n".join(python_lines), input_path, "exec")
    exec(python_bytecode, exec_globals)

    return output_stream.getvalue()


class SPVGenerator:
    """Collects GLSL templates and YAML variant files, expands and compiles them."""

    def __init__(
        self,
        src_dir_paths: Union[str, List[str]],
        env: Dict[Any, Any],
        glslc_path: Optional[str],
        glslc_flags: str = "",
        replace_u16vecn: bool = False,
    ) -> None:
        if isinstance(src_dir_paths, str):
            self.src_dir_paths = [src_dir_paths]
        else:
            self.src_dir_paths = src_dir_paths

        self.env = env
        self.glslc_path = glslc_path
        self.glslc_flags = glslc_flags
        self.replace_u16vecn = replace_u16vecn

        self.glsl_src_files: Dict[str, str] = {}
        self.template_yaml_files: List[str] = []

        self.addSrcAndYamlFiles(self.src_dir_paths)
        self.shader_template_params: Dict[Any, Any] = {}
        for yaml_file in self.template_yaml_files:
            self.parseTemplateYaml(yaml_file)

        self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
        self.constructOutputMap()

    def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
        """Recursively collect ``*.glsl*`` sources and ``*.yaml`` variant files."""
        for src_path in src_dir_paths:
            # Collect glsl source files
            glsl_files = glob.glob(
                os.path.join(src_path, "**", "*.glsl*"), recursive=True
            )
            for file in glsl_files:
                if len(file) > 1:
                    self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
            # Collect template yaml files
            yaml_files = glob.glob(
                os.path.join(src_path, "**", "*.yaml"), recursive=True
            )
            for file in yaml_files:
                if len(file) > 1:
                    self.template_yaml_files.append(file)

    def generateVariantCombinations(
        self,
        iterated_params: Dict[str, Any],
        exclude_params: Optional[Set[str]] = None,
    ) -> List[Any]:
        """Return the cartesian product of ``(name, suffix, value)`` choices.

        Each parameter lists entries of the form ``{VALUE: v[, SUFFIX: s]}`` or
        ``{RANGE: [a, b][, SUFFIX: s]}`` (inclusive range). Parameters named in
        ``exclude_params`` are skipped.
        """
        if exclude_params is None:
            exclude_params = set()
        all_iterated_params = []
        for param_name, value_list in iterated_params.items():
            if param_name not in exclude_params:
                param_values = []
                for value in value_list:
                    if "RANGE" in value:
                        value_range = value["RANGE"]
                        suffix = value.get("SUFFIX", "")
                        if isinstance(value_range, list) and len(value_range) == 2:
                            for i in range(value_range[0], value_range[1] + 1):
                                curr_suffix = (
                                    suffix + "_" + str(i) if suffix else str(i)
                                )
                                param_values.append((param_name, curr_suffix, i))
                        else:
                            raise ValueError(
                                f"{value['RANGE']} is not a valid range. Must be in format [start, end] (inclusive)."
                            )

                    elif "VALUE" in value:
                        suffix = value.get("SUFFIX", value["VALUE"])
                        param_values.append((param_name, suffix, value["VALUE"]))

                    else:
                        raise KeyError(
                            "Parameter must be 'VALUE: string' or 'RANGE: [a, b]'"
                        )

                all_iterated_params.append(param_values)

        return list(product(*all_iterated_params))

    def parseTemplateYaml(self, yaml_file: str) -> None:
        """Parse one variant YAML file into ``self.shader_template_params``."""
        with open(yaml_file) as f:
            # NOTE(review): UniqueKeyLoader derives from the full (non-safe)
            # Loader; acceptable for in-repo files, not for untrusted input.
            contents = yaml.load(f, Loader=UniqueKeyLoader)
        # An empty YAML document loads as None; nothing to register.
        if not contents:
            return
        for template_name, params_dict in contents.items():
            if template_name in self.shader_template_params:
                raise KeyError(f"{template_name} params file is defined twice")

            default_params = params_dict["parameter_names_with_default_values"]
            params_names = set(default_params.keys()).union({"NAME"})

            self.shader_template_params[template_name] = []

            default_iterated_params = params_dict.get(
                "generate_variant_forall", None
            )

            for variant in params_dict["shader_variants"]:
                variant_params_names = set(variant.keys())
                invalid_keys = (
                    variant_params_names
                    - params_names
                    - {"generate_variant_forall"}
                )
                assert len(invalid_keys) == 0

                iterated_params = variant.get(
                    "generate_variant_forall", default_iterated_params
                )

                if iterated_params is not None:
                    variant_combinations = self.generateVariantCombinations(
                        iterated_params, variant_params_names
                    )

                    for combination in variant_combinations:
                        default_params_copy = copy.deepcopy(default_params)
                        for key in variant:
                            if key != "generate_variant_forall":
                                default_params_copy[key] = variant[key]

                        variant_name = variant["NAME"]
                        for param_value in combination:
                            default_params_copy[param_value[0]] = param_value[2]
                            if len(str(param_value[1])) > 0:
                                variant_name = f"{variant_name}_{param_value[1]}"

                        default_params_copy["NAME"] = variant_name

                        self.shader_template_params[template_name].append(
                            default_params_copy
                        )
                else:
                    default_params_copy = copy.deepcopy(default_params)
                    for key in variant:
                        default_params_copy[key] = variant[key]

                    self.shader_template_params[template_name].append(
                        default_params_copy
                    )

    def create_shader_params(
        self, variant_params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, str]:
        """Return a deep copy of the environment overlaid with ``variant_params``.

        The deep copy gives every shader its own ``B``/``C`` counter lists.
        """
        if variant_params is None:
            variant_params = {}
        shader_params = copy.deepcopy(self.env)
        for key, value in variant_params.items():
            shader_params[key] = value

        return shader_params

    def constructOutputMap(self) -> None:
        """Map each output shader name to ``(source_glsl, params)``."""
        for shader_name, params in self.shader_template_params.items():
            for variant in params:
                source_glsl = self.glsl_src_files[shader_name]

                self.output_shader_map[variant["NAME"]] = (
                    source_glsl,
                    self.create_shader_params(variant),
                )

        for shader_name, source_glsl in self.glsl_src_files.items():
            if shader_name not in self.shader_template_params:
                self.output_shader_map[shader_name] = (
                    source_glsl,
                    self.create_shader_params(),
                )

    def maybe_replace_u16vecn(self, input_text: str) -> str:
        """
        There is a latency benefit to using u16vecn variables to store texture position
        variables instead of ivecn, likely due to reduced register pressure. However,
        SwiftShader does not support 16 bit integer types in shaders, so this is a crude
        way to fallback to using ivecn to store texture positions so that testing with
        SwiftShader is still possible.
        """
        if not self.replace_u16vecn:
            return input_text
        if "codegen-nosub" in input_text:
            return input_text

        input_text = input_text.replace("u16vec", "ivec")
        input_text = input_text.replace("uint16_t", "int")
        return input_text

    def generateSPV(self, output_dir: str) -> Dict[str, str]:
        """Expand every shader into ``output_dir`` and compile it with glslc.

        Returns:
            Mapping of ``.spv`` path -> generated ``.glsl`` path. When
            ``glslc_path`` is None only the GLSL is written and the ``.spv``
            paths in the map will not exist on disk.
        """
        output_file_map = {}

        def process_shader(shader_paths_pair):
            shader_name = shader_paths_pair[0]

            source_glsl = shader_paths_pair[1][0]
            shader_params = shader_paths_pair[1][1]

            with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
                input_text = input_file.read()
                input_text = self.maybe_replace_u16vecn(input_text)
                output_text = preprocess(input_text, shader_params)

            glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
            with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file:
                output_file.write(output_text)

            # Computed unconditionally so the return below is defined even when
            # no compiler is configured (previously an UnboundLocalError).
            spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

            # If no GLSL compiler is specified, then only write out the generated GLSL shaders.
            # This is mainly for testing purposes.
            if self.glslc_path is not None:
                cmd = (
                    [
                        self.glslc_path,
                        "-fshader-stage=compute",
                        glsl_out_path,
                        "-o",
                        spv_out_path,
                        "--target-env=vulkan1.1",
                        "-Werror",
                    ]
                    + [
                        arg
                        for src_dir_path in self.src_dir_paths
                        for arg in ["-I", src_dir_path]
                    ]
                    + self.glslc_flags.split()
                )

                subprocess.check_call(cmd)

            return (spv_out_path, glsl_out_path)

        # Parallelize shader compilation as much as possible to optimize build time.
        with ThreadPool(os.cpu_count()) as pool:
            for spv_out_path, glsl_out_path in pool.map(
                process_shader, self.output_shader_map.items()
            ):
                output_file_map[spv_out_path] = glsl_out_path

        return output_file_map


##############################################
#  Shader Info and Shader Registry Handling  #
##############################################


@dataclass
class ShaderInfo:
    tile_size: List[int]
    layouts: List[str]
    weight_storage_type: str = ""
    bias_storage_type: str = ""
    register_for: Optional[Tuple[str, List[str]]] = None


def getName(filePath: str) -> str:
    """Return the basename of ``filePath`` with ``/`` and ``.`` replaced by ``_``."""
    return os.path.basename(filePath).replace("/", "_").replace(".", "_")


def isDescriptorLine(lineStr: str) -> bool:
    descriptorLineId = r"^layout\(set"
    return re.search(descriptorLineId, lineStr) is not None


def isTileSizeLine(lineStr: str) -> bool:
    tile_size_id = r"^ \* TILE_SIZE = \("
    return re.search(tile_size_id, lineStr) is not None


def findTileSizes(lineStr: str) -> List[int]:
    tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
    matches = re.search(tile_size_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in findTileSizes")
    return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]


def isWeightStorageTypeLine(lineStr: str) -> bool:
    weight_storage_id = r"^ \* WEIGHT_STORAGE = "
    return re.search(weight_storage_id, lineStr) is not None


def getWeightStorageType(lineStr: str) -> str:
    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
    matches = re.search(weight_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return matches.group(1)


def isBiasStorageTypeLine(lineStr: str) -> bool:
    bias_storage_id = r"^ \* BIAS_STORAGE = "
    return re.search(bias_storage_id, lineStr) is not None


def getBiasStorageType(lineStr: str) -> str:
    bias_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
    matches = re.search(bias_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return matches.group(1)


def isRegisterForLine(lineStr: str) -> bool:
    # Check for Shader Name and a list of at least one Registry Key
    register_for_id = (
        r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
    )
    return re.search(register_for_id, lineStr) is not None


def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
    """Return ``(op_name, registry_keys)`` from a ``REGISTER_FOR`` line."""
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    matches = re.findall(register_for_pattern, lineStr)
    # re.findall returns a (possibly empty) list, never None.
    if len(matches) < 2:
        raise AssertionError(
            "expected an op name and at least one registry key in findRegisterFor"
        )
    return (matches[0], matches[1:])


typeIdMapping = {
    r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
    r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
    r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
    r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER",
}


def determineDescriptorType(lineStr: str) -> str:
    """Return the first Vulkan descriptor type whose pattern matches ``lineStr``."""
    for identifier, typeNum in typeIdMapping.items():
        if re.search(identifier, lineStr):
            return typeNum
    raise AssertionError(
        "No matching descriptor type for " + lineStr + " in determineDescriptorType"
    )


def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Scan a generated GLSL file for layouts and ``*`` header annotations."""
    shader_info = ShaderInfo([], [], "")
    with open(srcFilePath) as srcFile:
        for line in srcFile:
            if isDescriptorLine(line):
                shader_info.layouts.append(determineDescriptorType(line))
            if isTileSizeLine(line):
                shader_info.tile_size = findTileSizes(line)
            if isWeightStorageTypeLine(line):
                shader_info.weight_storage_type = getWeightStorageType(line)
            if isBiasStorageTypeLine(line):
                shader_info.bias_storage_type = getBiasStorageType(line)
            if isRegisterForLine(line):
                shader_info.register_for = findRegisterFor(line)

    return shader_info


##########################
#  C++ File Generation  #
#########################

# The registration function pointer must be "&register_fn" (the "&reg" prefix
# is easily mangled into a "®" character by HTML-entity processing).
cpp_template = """
#include <executorch/backends/vulkan/runtime/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>

using namespace vkcompute;

namespace at {{
namespace native {{
namespace vulkan {{

namespace {{

{spv_bin_arrays}

}}

static void register_fn() {{

{register_shader_infos}

{shader_info_registry}

}}

static const api::ShaderRegisterInit register_shaders(&register_fn);

}}
}}
}}

"""


def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]:
    """Return ``(size_in_bytes, C++ uint32_t array literal)`` for a SPIR-V file."""
    with open(spvPath, "rb") as fr:
        next_bin = array.array("I", fr.read())
        sizeBytes = 4 * len(next_bin)
        spv_bin_str = "const uint32_t {}_bin[] = {{\n{}\n}};".format(
            name,
            textwrap.indent(",\n".join(str(x) for x in next_bin), "  "),
        )

    return sizeBytes, spv_bin_str


def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
    """Return the C++ ``register_shader`` call for one shader."""
    tile_size = (
        f"{{{', '.join(str(x) for x in shader_info.tile_size)}}}"
        if (len(shader_info.tile_size) > 0)
        else "{1, 1, 1}"
    )

    shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))

    shader_info_args = [
        f'"{name}"',
        f"{name}_bin",
        str(sizeBytes),
        shader_info_layouts,
        tile_size,
    ]

    shader_info_str = textwrap.indent(
        "api::shader_registry().register_shader(\n  vkapi::ShaderInfo(\n{args}));\n".format(
            args=textwrap.indent(",\n".join(shader_info_args), "    "),
        ),
        "  ",
    )

    return shader_info_str


def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
    """Return the C++ ``register_op_dispatch`` calls for one shader.

    One registration line is emitted per registry key. (Previously the loop
    overwrote its result, so only the last key was registered.)
    """
    if shader_info.register_for is None:
        return ""

    (op_name, registry_keys) = shader_info.register_for
    dispatch_lines = [
        textwrap.indent(
            f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");',
            "  ",
        )
        for registry_key in registry_keys
    ]
    return "\n".join(dispatch_lines)


def genCppFiles(
    spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
    """Write the C++ source that embeds and registers every compiled shader.

    ``cpp_header_path`` is accepted for interface compatibility but no header
    is written here.
    """
    spv_bin_strs = []
    register_shader_info_strs = []
    shader_registry_strs = []

    for spvPath, srcPath in spv_files.items():
        name = getName(spvPath).replace("_spv", "")

        sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
        spv_bin_strs.append(spv_bin_str)

        shader_info = getShaderInfo(srcPath)

        register_shader_info_strs.append(
            generateShaderInfoStr(shader_info, name, sizeBytes)
        )

        if shader_info.register_for is not None:
            shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))

    spv_bin_arrays = "\n".join(spv_bin_strs)
    register_shader_infos = "\n".join(register_shader_info_strs)
    shader_info_registry = "\n".join(shader_registry_strs)

    cpp = cpp_template.format(
        spv_bin_arrays=spv_bin_arrays,
        register_shader_infos=register_shader_infos,
        shader_info_registry=shader_info_registry,
    )

    with open(cpp_src_file_path, "w") as fw:
        fw.write(cpp)


##########
#  Main  #
##########


def parse_arg_env(items: Optional[List[str]]) -> Dict[Any, Any]:
    """Parse ``KEY=VALUE`` strings into a dict (both sides stripped).

    Only the first ``=`` separates key from value, so values may contain ``=``.
    """
    d = {}
    if items:
        for item in items:
            tokens = item.split("=", 1)
            key = tokens[0].strip()
            value = tokens[1].strip()
            d[key] = value
    return d


def main(argv: List[str]) -> int:
    """Command-line entry point.

    NOTE: ``argv`` is currently unused; arguments are read from ``sys.argv``
    by ``argparse``.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-i",
        "--glsl-paths",
        nargs="+",
        help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
        default=["."],
    )
    parser.add_argument("-c", "--glslc-path", required=True, help="")
    parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
    parser.add_argument("-o", "--output-path", required=True, help="")
    parser.add_argument("--replace-u16vecn", action="store_true", default=False)
    parser.add_argument("--optimize_size", action="store_true", help="")
    parser.add_argument("--optimize", action="store_true", help="")
    parser.add_argument(
        "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
    )
    options = parser.parse_args()

    # Copy so the module-level DEFAULT_ENV is not mutated.
    env = dict(DEFAULT_ENV)
    env.update(TYPE_MAPPINGS)
    env.update(UTILITY_FNS)

    for key, value in parse_arg_env(options.env).items():
        env[key] = value

    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    glslc_flags = ""
    if options.optimize_size:
        glslc_flags += "-Os"
    elif options.optimize:
        glslc_flags += "-O"

    shader_generator = SPVGenerator(
        options.glsl_paths,
        env,
        options.glslc_path,
        glslc_flags=glslc_flags,
        replace_u16vecn=options.replace_u16vecn,
    )
    output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

    genCppFiles(
        output_spv_files,
        f"{options.output_path}/{CPP_H_NAME}",
        f"{options.output_path}/{CPP_SRC_NAME}",
    )

    return 0


def invoke_main() -> None:
    sys.exit(main(sys.argv))


if __name__ == "__main__":
    invoke_main()  # pragma: no cover