#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import os
import sys
from typing import Any, List

import yaml

from torchgen.code_template import CodeTemplate


# One per-operator clause: matches the operator name AND any of its dtypes.
ops_and_dtypes_template_str = """((exec_aten::string_view(operator_name).compare("$operator_name") == 0)\n && ($dtype_checks))"""
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)

# Skeleton of the generated header; $body is the OR of all operator clauses.
selected_kernel_dtypes_h_template_str = """#pragma once
/**
 * Generated by executorch/codegen/tools/gen_selected_op_variants.py
 */

inline constexpr bool should_include_kernel_dtype(
  const char *operator_name,
  exec_aten::ScalarType scalar_type
) {
  return $body;
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)

# enum from: https://github.com/pytorch/executorch/blob/main/runtime/core/portable_type/scalar_type.h
dtype_enum_to_type = {
    "0": "Byte",
    "1": "Char",
    "2": "Short",
    "3": "Int",
    "4": "Long",
    "5": "Half",
    "6": "Float",
    "7": "Double",
    "8": "ComplexHalf",
    "9": "ComplexFloat",
    "10": "ComplexDouble",
    "11": "Bool",
    "12": "QInt8",
    "13": "QUInt8",
    "14": "QInt32",
    "15": "BFloat16",
    "16": "QUInt4x2",
    "17": "QUInt2x4",
    "18": "Bits1x8",
    "19": "Bits2x4",
    "20": "Bits4x2",
    "21": "Bits8",
    "22": "Bits16",
}


def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
    """Generate ``selected_op_variants.h`` from a selected_operators.yaml file.

    Reads the ``et_kernel_metadata`` mapping (operator name -> list of kernel
    metadata strings) and emits a constexpr C++ predicate that answers whether
    a given (operator, scalar_type) pair was selected for the build.

    Args:
        yaml_file_path: Path to the selected_operators.yaml file.
        output_dir: Directory into which selected_op_variants.h is written.
    """
    with open(yaml_file_path, "r") as selected_operators_file:
        # Collect et_kernel_metadata from selected_operators.yaml and extract dtypes
        # Example format: v1/6;0,1|6;0,1|6;0,1|6;0,1  # Float, 0, 1
        selected_operators_dict = yaml.safe_load(selected_operators_file)
        et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata", {})
        assert isinstance(et_kernel_metadata, dict)
        body_parts = []
        for operator_name, kernel_metadata_str in et_kernel_metadata.items():
            tensor_meta = []
            for kernel_metadata in kernel_metadata_str:
                # "default" (or any entry without a version prefix) means the
                # operator was selected for all dtypes; stop collecting so the
                # clause below degenerates to "true".
                if kernel_metadata == "default" or "/" not in kernel_metadata:
                    break
                else:
                    # Strip the "v1/" version prefix; entries for each tensor
                    # are separated by "|".
                    x = kernel_metadata.split("/")[1]
                    tensor_meta.extend(x.split("|"))
            conditions = ["true"]
            if len(tensor_meta) > 0:
                # Each tensor entry is "<dtype_enum>;<dim_order...>"; only the
                # dtype part matters here. Sort for deterministic output.
                dtype_set = {x.split(";")[0] for x in tensor_meta}
                dtype_list = sorted([dtype_enum_to_type[x] for x in dtype_set])
                conditions = [
                    "scalar_type == exec_aten::ScalarType::" + x for x in dtype_list
                ]
            body_parts.append(
                ops_and_dtypes_template.substitute(
                    operator_name=operator_name.replace("aten::", ""),
                    dtype_checks=" || ".join(conditions),
                ),
            )
        # Fall back to "true" when no operators were selected; joining an
        # empty list would otherwise emit the invalid C++ "return ;".
        body = "\n || ".join(body_parts) if body_parts else "true"
        header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
        selected_op_variants_path = os.path.join(output_dir, "selected_op_variants.h")
        with open(selected_op_variants_path, "wb") as out_file:
            out_file.write(header_contents.encode("utf-8"))


def main(argv: List[Any]) -> None:
    """Parse command-line arguments and generate the selected-op-variants header."""
    parser = argparse.ArgumentParser(description="Generate operator lists")
    parser.add_argument(
        "--yaml-file-path",
        "--yaml_file_path",
        help="The directory where selected_operators.yaml was generated",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help=(
            "The directory to store the output yaml files (selected_op_variants.h, "
            + "selected_kernel_dtypes.h, selected_operators.yaml)"
        ),
        required=True,
    )

    options = parser.parse_args(argv)
    write_selected_op_variants(options.yaml_file_path, options.output_dir)


if __name__ == "__main__":
    main(sys.argv[1:])