from __future__ import annotations

from collections import defaultdict, namedtuple
from typing import Any

import yaml

from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader, parse_native_yaml
from torchgen.model import (
    BackendMetadata,
    DispatchKey,
    FunctionSchema,
    NativeFunction,
    OperatorName,
)
from torchgen.utils import NamespaceHelper


# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])

# Fields in native_functions.yaml used to determine which kernels should be used
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]


def _without_line_key(mapping: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``mapping`` without the ``__line__`` bookkeeping key.

    Building a new dict (rather than popping the key) leaves the caller's loaded
    yaml untouched, so the same document can be parsed more than once.
    """
    return {k: v for k, v in mapping.items() if k != "__line__"}


def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

    The input is not modified: ``__line__`` markers are dropped from copies of the
    nested mappings instead of being popped from the originals.

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}
            See ETKernelKey for description of arguments

    Returns:
        Mapping from each ETKernelKey to the BackendMetadata of the kernel that
        serves it; an empty dict when ``ei`` has no ``kernels`` entry.

    Raises:
        AssertionError: if a kernel entry's ``kernel_name`` is not a string, or if
            two kernel entries produce the same kernel key.
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    # `or {}` also covers keys that are present in the yaml but have a null value.
    type_alias: dict[str, list[str]] = e.pop("type_alias", None) or {}  # type: ignore[assignment]
    dim_order_alias: dict[str, list[str]] = _without_line_key(
        e.pop("dim_order_alias", None) or {}  # type: ignore[arg-type]
    )

    kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:  # type: ignore[attr-defined]
        arg_meta = entry.get("arg_meta")
        if arg_meta is not None:
            # Copy rather than pop: popping would mutate the shared yaml entry and
            # make a second parse of the same entry fail with KeyError.
            arg_meta = _without_line_key(arg_meta)

        kernel_name = entry.get("kernel_name")
        assert isinstance(kernel_name, str), f"not a str: {kernel_name}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        # An entry without arg_meta is registered under the single default key;
        # otherwise ETKernelKey expands the aliases into one key per combination.
        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
        )

        for kernel_key in kernel_keys:
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping


def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Given a loaded yaml representing a list of operators, for each op extract the mapping
    of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
    that should be used by the kernel key).

    Operators whose entry yields no kernel mapping are left out of the index.

    Raises:
        AssertionError: if an entry's ``func`` is not a string, or an operator that
            already produced kernels appears again.
    """
    indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
    for ei in es:  # type: ignore[attr-defined]
        e = ei.copy()

        funcs = e.pop("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        opname = FunctionSchema.parse(namespace_helper.entity_name).name

        # NOTE(review): only ops with a non-empty index are recorded in `indices`,
        # so a duplicate whose earlier occurrence had no kernels is not caught here.
        assert opname not in indices, f"Duplicate func found in yaml: {opname} already"

        if len(index := parse_from_yaml(e)) != 0:
            indices[opname] = index

    return ETKernelIndex(indices)


def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
    """Given a loaded yaml representing a list of operators, extract the
    kernel key related fields indexed by the operator name.

    Only the ET_FIELDS keys that are present with a non-None value are kept.
    """
    fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
    for ei in es:  # type: ignore[attr-defined]
        funcs = ei.get("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        opname = FunctionSchema.parse(namespace_helper.entity_name).name

        for field in ET_FIELDS:
            if (value := ei.get(field)) is not None:
                fields[opname][field] = value

    return fields


def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
    """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
    of fields to persist from native_functions.yaml to functions.yaml

    The ET-specific fields are extracted first and then stripped from the loaded
    entries before they are handed to ``parse_native_yaml``.
    """
    with open(path) as f:
        es = yaml.load(f, Loader=LineLoader)

    et_kernel = extract_kernel_fields(es)

    # Remove ET specific fields from entries for BC compatibility
    strip_et_fields(es)

    native_yaml = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=es,
    )
    return native_yaml.native_functions, et_kernel


def strip_et_fields(es: object) -> None:
    """Given a loaded yaml representing a list of operators,
    remove ET specific fields from every entry (in place) for BC compatibility
    """
    for entry in es:  # type: ignore[attr-defined]
        for field in ET_FIELDS:
            entry.pop(field, None)