# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.

from __future__ import annotations

import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum

from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
)
from torchgen.utils import assert_never


# Version tag used by default when an ETKernelKey is rendered as a string.
KERNEL_KEY_VERSION = 1


# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
    """Dtype names mapped to the integer codes emitted in kernel key strings."""

    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11


ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])


@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    """Dtype and dimension order of a single operator argument."""

    arg_name: str
    # Name of a ScalarType member, e.g. "Float".
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: tuple[int, ...]

    def to_native_string(self) -> str:
        """Render this argument as ``"<dtype code>;<d0>,<d1>,..."``.

        Raises:
            KeyError: if ``dtype`` is not the name of a ``ScalarType`` member.
        """
        dtype_code = ScalarType[self.dtype].value
        # Join explicitly instead of slicing str(tuple): the tuple repr of a
        # one-element dim order is "(0,)", which would leave a trailing comma.
        dim_str = ",".join(str(dim) for dim in self.dim_order)
        return f"{dtype_code};{dim_str}"


@dataclass(frozen=True)
class ETKernelKey:
    """Hashable key identifying one kernel specialization of an operator."""

    # Per-argument metadata; left empty for the catch-all (default) key.
    arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()

    # Indicator for this kernel being used as a catch all
    default: bool = False

    # Version tag rendered into the native string.
    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: dict[str, tuple[str, str]],
        type_alias_map: dict[str, list[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: dict[str, list[int]],
    ) -> list[ETKernelKey]:
        """Generate ETKernelKeys from arg kernel specs
        Multiple ETKernelKeys are returned due to dtype permutations from utilizing
        type_alias_map (actualizing each potential type permutation as a KernelKey)

        Args:
            args: Mapping from argument name to kernel specs
                Kernel specs are a tuple of (dtype, dim_order).
                Currently tuple entries must be aliased via the alias map arguments
            type_alias_map: Mapping from type alias to potential type enums
                i.e { T0 : [Double, Int] } means T0 can be either Double or Int
                Used for lookup by args
            dim_order_alias_map: Mapping from alias to a list of dimension orders
                Used for lookup by args

        Returns:
            One ETKernelKey per combination of concrete dtypes for the type
            aliases referenced by ``args``. Aliases are expanded in the order
            they first appear in ``args`` and dtypes in the order listed in
            ``type_alias_map``, so the result order is deterministic.

        Raises:
            AssertionError: if an argument references an undefined type alias
                or an undefined dim_order alias.
        """
        # Cast every dim-order entry to int; stored as tuples so they are hashable.
        dim_orders: dict[str, tuple[int, ...]] = {
            alias: tuple(int(dim) for dim in dims)
            for alias, dims in dim_order_alias_map.items()
        }

        # Collect the dtype aliases in use. A dict (not a set) de-duplicates
        # while preserving first-appearance order, which keeps the order of
        # the generated keys independent of string hash randomization.
        dtype_aliases_used: dict[str, None] = {}
        for type_alias, dim_order_alias in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert (
                dim_order_alias in dim_orders
            ), "Undefined dim_order alias: " + str(dim_order_alias)
            dtype_aliases_used[type_alias] = None

        # One list of (alias, dtype) choices per alias; their cartesian product
        # enumerates every concrete alias -> dtype assignment.
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in dtype_aliases_used
        ]

        # Share identical argument metadata objects between generated keys.
        op_arg_cache: dict[tuple[str, str, tuple[int, ...]], ETKernelKeyOpArgMeta] = {}
        kernel_keys: list[ETKernelKey] = []
        for permutation in itertools.product(*alias_dtypes):
            alias_to_dtype = dict(permutation)
            arg_list = []
            for arg_name, (type_alias, dim_order_alias) in args.items():
                cache_key = (
                    arg_name,
                    alias_to_dtype[type_alias],
                    dim_orders[dim_order_alias],
                )
                arg_meta = op_arg_cache.get(cache_key)
                if arg_meta is None:
                    arg_meta = ETKernelKeyOpArgMeta(*cache_key)
                    op_arg_cache[cache_key] = arg_meta
                arg_list.append(arg_meta)
            kernel_keys.append(ETKernelKey(tuple(arg_list)))

        return kernel_keys

    def to_native_string(self) -> str:
        """Render the key as ``"default"`` or ``"v<version>/<arg>|<arg>|..."``."""
        if self.default:
            return "default"
        # Use the instance's own version so a key built with a non-default
        # version is not silently rendered with the module-level constant.
        rendered_args = "|".join(arg.to_native_string() for arg in self.arg_meta)
        return f"v{self.version}/{rendered_args}"


@dataclass(frozen=True)
class ETKernelIndex:
    """Mapping from operator name to its kernels, keyed by ETKernelKey."""

    index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
        """Return True if at least one kernel is registered for ``g``."""
        # get_kernels() returns an empty dict (never None) for unknown
        # operators, so test for emptiness rather than for None.
        return bool(self.get_kernels(g))

    def get_kernels(
        self, g: NativeFunction | NativeFunctionsGroup
    ) -> dict[ETKernelKey, BackendMetadata]:
        """Return the kernels registered for ``g``, or an empty dict if none.

        For a NativeFunctionsGroup the lookup uses its functional variant.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = g.functional
        else:
            assert_never(g)
        # .get() never inserts, even when the index is a defaultdict.
        return self.index.get(f.func.name, {})

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Register every backend entry in ``kernel_index`` as a default kernel.

        ``kernel_index`` is modified in place. An existing default kernel for
        an operator is overwritten; when several dispatch keys provide the same
        operator, the one iterated last wins.
        """
        default_key = ETKernelKey(default=True)
        for index in backend_indices.values():
            for op, backend_metadata in index.items():
                kernel_index.setdefault(op, {})[default_key] = backend_metadata

    @staticmethod
    def from_backend_indices(
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
    ) -> ETKernelIndex:
        """Build a new index whose kernels are all default (catch-all) kernels."""
        kernel_index: dict[
            OperatorName, dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
    ) -> ETKernelIndex:
        """Add ``backend_indices`` to this index in place and return ``self``."""
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.

        Every operator must have exactly one kernel. If that kernel is not the
        default kernel, an empty BackendMetadata placeholder is used instead.

        Raises:
            AssertionError: if any operator does not have exactly one kernel.
        """
        index: dict[OperatorName, BackendMetadata] = {}
        default_key = ETKernelKey(default=True)
        for op, kernel_dict in self.index.items():
            assert (
                len(kernel_dict) == 1
            ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
            index[op] = kernel_dict.get(
                default_key,
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
        """Return a new index holding the kernels of both inputs.

        Neither input is modified: the per-operator dicts are rebuilt rather
        than shared, so writing index_b's entries cannot leak into index_a.
        """
        combined: defaultdict[
            OperatorName, dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        # index_a first, then index_b, so index_b wins on duplicate keys.
        for source in (index_a, index_b):
            for op, kernels in source.index.items():
                combined[op].update(kernels)

        return ETKernelIndex(combined)