# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import os
import sys
from enum import IntEnum
from typing import Any, Dict, List, Optional, Set

import yaml
from torchgen.executorch.parse import strip_et_fields

from torchgen.gen import LineLoader, parse_native_yaml_struct
from torchgen.selective_build.operator import SelectiveBuildOperator
from torchgen.selective_build.selector import merge_et_kernel_metadata

# Output YAML file format:
# ------------------------
#
# <BEGIN FILE CONTENTS>
# include_all_non_op_selectives: False
# include_all_operators: False
# debug_info:
#   - model1@v100
#   - model2@v50
# operators:
#   aten::add:
#     is_root_operator: Yes
#     is_used_for_training: Yes
#     include_all_overloads: No
#     debug_info:
#       - model1@v100
#       - model2@v50
#   aten::add.int:
#     is_root_operator: No
#     is_used_for_training: No
#     include_all_overloads: Yes
# et_kernel_metadata:
#   aten::add.out:
#     # A list of different kernel keys (tensors with dtype-enum/dim-order) combinations used in model
#     - v1/6;0,1|6;0,1|6;0,1|6;0,1  # Float, 0, 1
#     - v1/3;0,1|3;0,1|3;0,1|3;0,1  # Int, 0, 1
#   aten::mul.out:
#     - v1/6;0,1|6;0,1|6;0,1|6;0,1  # Float, 0, 1
# <END FILE CONTENTS>


class ScalarType(IntEnum):
    """Scalar dtype enum values as they appear in kernel keys.

    NOTE(review): values presumably mirror the ExecuTorch runtime's ScalarType
    numbering — the gaps (5, 8-10, 15) correspond to the commented-out types
    below. Confirm against the runtime before extending.
    """

    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11
    # TODO(jakeszwe): Verify these are unused and then remove support
    QInt8 = 12
    QUInt8 = 13
    QInt32 = 14
    QUInt4X2 = 16
    QUInt2X4 = 17
    # Types currently not implemented.
    # Half = 5
    # ComplexHalf = 8
    # ComplexFloat = 9
    # ComplexDouble = 10
    # BFloat16 = 15


class KernelType(IntEnum):
    """Argument kinds that carry dtype/dim-order info in a kernel key.

    Only arguments of these kinds contribute a "<dtype>;<dim_order>" segment
    when _get_kernel_metadata_for_model builds a kernel key.
    """

    TENSOR = 5
    TENSOR_LIST = 10
    OPTIONAL_TENSOR_LIST = 11


def _get_operators(model_file: str) -> List[str]:
    """Return the list of operator names used by a serialized ExecuTorch program."""
    # Imported lazily so that this module can be used (e.g. with --root_ops)
    # without the executorch extension being installed.
    from executorch.codegen.tools.selective_build import (
        _get_program_from_buffer,
        _get_program_operators,
    )

    print("Processing model file: ", model_file)
    with open(model_file, "rb") as f:
        buf = f.read()

    program = _get_program_from_buffer(buf)
    operators = _get_program_operators(program)
    print(f"Model file loaded, operators are: {operators}")
    return operators


def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
    """Map each operator in the model to the list of kernel keys it is used with.

    A kernel key encodes the (dtype, dim-order) combination of every tensor-like
    argument, e.g. "v1/6;0,1|6;0,1" (see the file-format comment at the top).
    """
    # Imported lazily; see _get_operators.
    from executorch.codegen.tools.selective_build import (
        _get_io_metadata_for_program_operators,
        _get_program_from_buffer,
        _IOMetaData,
    )

    with open(model_file, "rb") as f:
        buf = f.read()

    program = _get_program_from_buffer(buf)
    operators_with_io_metadata = _get_io_metadata_for_program_operators(program)

    op_kernel_key_list: Dict[str, List[str]] = {}

    specialized_kernels: Set[List[_IOMetaData]]
    for op_name, specialized_kernels in operators_with_io_metadata.items():
        print(op_name)
        if op_name not in op_kernel_key_list:
            op_kernel_key_list[op_name] = []

        for specialized_kernel in specialized_kernels:
            version = "v1"
            kernel_key = version + "/"
            for io_metadata in specialized_kernel:
                # Only tensor-like arguments contribute to the kernel key.
                if io_metadata.kernel_type in [
                    KernelType.TENSOR,
                    KernelType.TENSOR_LIST,
                    KernelType.OPTIONAL_TENSOR_LIST,
                ]:
                    dim_order = ",".join(map(str, io_metadata.dim_order))
                    kernel_key += f"{io_metadata.dtype};{dim_order}|"
            # Drop the trailing "|" separator appended after the last segment.
            op_kernel_key_list[op_name].append(kernel_key[:-1])

    return op_kernel_key_list


def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[str]]:
    """Extract operator names from an ops-schema YAML and map them to ["default"].

    Entries with an "op" key are taken as (possibly namespace-less) operator
    names; all other entries are parsed as native-function declarations via
    torchgen and contribute their fully-qualified names.
    """
    ops = []
    # Explicit encoding: YAML schema files are UTF-8; don't depend on locale.
    with open(ops_yaml_path, "r", encoding="utf-8") as f:
        es = yaml.load(f, Loader=LineLoader)
    func_entries = []
    for e in es:
        if "op" in e:
            op_name = e["op"]
            # Names without an explicit namespace default to "aten::".
            ops.append(op_name if "::" in op_name else "aten::" + op_name)
        else:
            func_entries.append(e)
    strip_et_fields(es)
    parsed_yaml = parse_native_yaml_struct(
        func_entries, set(), None, path=ops_yaml_path, skip_native_fns_gen=True
    )
    ops.extend([f"{f.namespace}::{f.func.name}" for f in parsed_yaml.native_functions])
    # TODO (larryliu): accept the new op yaml syntax
    return {op: ["default"] for op in ops}


def _dump_yaml(
    op_list: List[str],
    output_path: str,
    model_name: Optional[str] = None,
    et_kernel_metadata: Optional[Dict[str, List[str]]] = None,
    include_all_operators: bool = False,
) -> None:
    """Write the selected-operators YAML (see the format comment at the top).

    Every operator in op_list is marked as a root, training-capable operator
    with model_name as its only debug-info entry.
    """
    # no debug info yet
    output: Dict[str, object] = {}
    operators: Dict[str, Dict[str, object]] = {}
    for op_name in op_list:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": True,
                "is_used_for_training": True,
                "include_all_overloads": False,
                "debug_info": [model_name],
            },
        )
        operators[op_name] = op.to_dict()

    output["operators"] = operators
    output["custom_classes"] = []
    output["build_features"] = []
    output["include_all_non_op_selectives"] = False
    output["include_all_operators"] = include_all_operators
    output["kernel_metadata"] = {}
    output["et_kernel_metadata"] = et_kernel_metadata
    # Encode explicitly and write binary so the output is UTF-8 regardless of locale.
    with open(output_path, "wb") as out_file:
        out_file.write(
            yaml.safe_dump(
                output,
                default_flow_style=False,
            ).encode("utf-8")
        )


def gen_oplist(
    output_path: str,
    model_file_path: Optional[str] = None,
    ops_schema_yaml_path: Optional[str] = None,
    root_ops: Optional[str] = None,
    ops_dict: Optional[str] = None,
    include_all_operators: bool = False,
) -> None:
    """Collect operators/kernel metadata from all provided sources and dump YAML.

    Sources are additive: root_ops (delimited string), ops_dict (JSON of
    op -> kernel keys), a serialized model file, and/or an ops-schema YAML.
    At least one source (or include_all_operators) must be given.
    """
    assert (
        model_file_path
        or ops_schema_yaml_path
        or root_ops
        or ops_dict
        or include_all_operators
    ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."

    assert output_path, "Need to provide output_path for dumped yaml file."
    op_set = set()
    source_name = None
    et_kernel_metadata = {}
    if root_ops:
        # decide delimiter
        delimiter = "," if "," in root_ops else " "
        print(root_ops)
        op_set.update(
            set(filter(lambda x: len(x) > 0, map(str.strip, root_ops.split(delimiter))))
        )
        et_kernel_metadata = merge_et_kernel_metadata(
            et_kernel_metadata, {op: ["default"] for op in op_set}
        )
    if ops_dict:
        ops_and_metadata = json.loads(ops_dict)
        for op, metadata in ops_and_metadata.items():
            op_set.update({op})
            # Ops listed without kernel keys fall back to the "default" kernel.
            op_metadata = metadata if len(metadata) > 0 else ["default"]
            et_kernel_metadata = merge_et_kernel_metadata(
                et_kernel_metadata, {op: op_metadata}
            )
    if model_file_path:
        assert os.path.isfile(
            model_file_path
        ), f"The value for --model_file_path needs to be a valid file, got {model_file_path}"
        op_set.update(_get_operators(model_file_path))
        source_name = model_file_path
        et_kernel_metadata = merge_et_kernel_metadata(
            et_kernel_metadata, _get_kernel_metadata_for_model(model_file_path)
        )
    if ops_schema_yaml_path:
        assert os.path.isfile(
            ops_schema_yaml_path
        ), f"The value for --ops_schema_yaml_path needs to be a valid file, got {ops_schema_yaml_path}"
        et_kernel_metadata = merge_et_kernel_metadata(
            et_kernel_metadata,
            _get_et_kernel_metadata_from_ops_yaml(ops_schema_yaml_path),
        )
        op_set.update(et_kernel_metadata.keys())
        source_name = ops_schema_yaml_path
    _dump_yaml(
        sorted(op_set),
        output_path,
        os.path.basename(source_name) if source_name else None,
        et_kernel_metadata,
        include_all_operators,
    )


def main(args: List[Any]) -> None:
    """This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
    It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
    it can be used in gen.py.
    """
    parser = argparse.ArgumentParser(
        description="Generate operator list from a model file"
    )
    parser.add_argument(
        "--output_path",
        help=("The path to the output yaml file (selected_operators.yaml)"),
        required=True,
    )
    parser.add_argument(
        "--model_file_path",
        help=("Path to an executorch program"),
        required=False,
    )
    parser.add_argument(
        "--ops_schema_yaml_path",
        help=("Dump operator names from operator schema yaml path"),
        required=False,
    )
    parser.add_argument(
        "--root_ops",
        help=("A comma separated list of root operators used by the model"),
        required=False,
    )
    parser.add_argument(
        "--ops_dict",
        help=(
            "A json object containing operators and their associated dtype and dim order"
        ),
        required=False,
    )
    parser.add_argument(
        "--include-all-operators",
        "--include_all_operators",
        action="store_true",
        default=False,
        help="Set this flag to request inclusion of all operators (i.e. build is not selective).",
        required=False,
    )
    options = parser.parse_args(args)

    try:
        gen_oplist(
            output_path=options.output_path,
            model_file_path=options.model_file_path,
            ops_schema_yaml_path=options.ops_schema_yaml_path,
            root_ops=options.root_ops,
            ops_dict=options.ops_dict,
            include_all_operators=options.include_all_operators,
        )
    except Exception as e:
        # Reconstruct an equivalent CLI invocation so the failure is easy to repro.
        command = ["python codegen/tools/gen_oplist.py"]
        if options.model_file_path:
            command.append(f"--model_file_path {options.model_file_path}")
        if options.ops_schema_yaml_path:
            command.append(f"--ops_schema_yaml_path {options.ops_schema_yaml_path}")
        if options.root_ops:
            command.append(f"--root_ops {options.root_ops}")
        if options.ops_dict:
            command.append(f"--ops_dict {options.ops_dict}")
        if options.include_all_operators:
            command.append("--include-all-operators")
        repro_command = " ".join(command)
        raise RuntimeError(
            f"""Failed to generate selected_operators.yaml. Repro command:
            {repro_command}
            """
        ) from e


if __name__ == "__main__":
    main(sys.argv[1:])