1#!/usr/bin/env python3 2 3from __future__ import annotations 4 5import argparse 6import json 7import sys 8from typing import Any 9 10import yaml 11from gen_op_registration_allowlist import ( 12 canonical_name, 13 gen_transitive_closure, 14 load_op_dep_graph, 15) 16 17from torchgen.selective_build.operator import ( 18 merge_operator_dicts, 19 SelectiveBuildOperator, 20) 21from torchgen.selective_build.selector import merge_kernel_metadata 22 23 24# Generate YAML file containing the operators used for a specific PyTorch model. 25# ------------------------------------------------------------------------------ 26# 27# This binary is responsible for generating the model_operators.yaml file for 28# each model from a pt_operator_library() BUCK macro invocation. 29# 30# Output YAML file format: 31# ------------------------ 32# 33# <BEGIN FILE CONTENTS> 34# include_all_non_op_selectives: False 35# include_all_operators: False 36# debug_info: 37# - model1@v100 38# - model2@v50 39# operators: 40# aten::add: 41# is_root_operator: Yes 42# is_used_for_training: Yes 43# include_all_overloads: No 44# debug_info: 45# - model1@v100 46# - model2@v50 47# aten::add.int: 48# is_root_operator: No 49# is_used_for_training: No 50# include_all_overloads: Yes 51# kernel_metadata: 52# add_kernel: 53# - Int8 54# - UInt32 55# sub_kernel: 56# - Int16 57# - Float 58# <END FILE CONTENTS> 59# 60# There are a few main inputs to this application 61# ----------------------------------------------- 62# 63# 1. Inference Root Operators (--root-ops): Root operators (called directly 64# from TorchScript) used by inference use-cases. 65# 66# 2. Training Root Operators (--training-root-ops): Root operators used 67# by training use-cases. Currently, this list is the list of all operators 68# used by training, and not just the root operators. All Training ops are 69# also considered for inference, so these are merged into inference ops. 70# 71# 3. Operator Depencency Graph (--dep-graph-yaml-path): A path to the 72# operator dependency graph used to determine which operators depend on 73# which other operators for correct functioning. This is used for 74# generating the transitive closure of all the operators used by the 75# model based on the root operators when static selective build is used. 76# For tracing based selective build, we don't need to perform this 77# transitive cloure. 78# 79# 4. Model Metadata (--model-name, --model-versions, --model-assets, 80# --model-backends): Self-descriptive. These are used to tell this 81# script which model operator lists to fetch from the Model 82# Build Metadata YAML files. 83# 84# 5. Model YAML files (--models-yaml-path): These yaml files contains 85# (for each model/version/asset/backend) the set of used root and traced 86# operators. This is used to extract the actual set of operators 87# needed to be included in the build. 88# 89 90 91def canonical_opnames(opnames: list[str]) -> list[str]: 92 return [canonical_name(opname) for opname in opnames] 93 94 95def make_filter_from_options( 96 model_name: str, 97 model_versions: list[str], 98 model_assets: list[str] | None, 99 model_backends: list[str] | None, 100): 101 def is_model_included(model_info) -> bool: 102 model = model_info["model"] 103 if model["name"] != model_name: 104 return False 105 if str(model["version"]) not in model_versions: 106 return False 107 if model_assets is not None and model["asset"] not in model_assets: 108 return False 109 # TODO: Handle backend later 110 return True 111 112 return is_model_included 113 114 115# Returns if a the specified rule is a new or old style pt_operator_library 116def is_new_style_rule(model_name: str, model_versions: list[str] | None): 117 return model_name is not None and model_versions is not None 118 119 120# Verifies that specified model_name, and all specified versions and assets 121# appear in at least one model yaml. Throws if verification is failed, 122# returns None on success 123def verify_all_specified_present( 124 model_assets: list[str] | None, 125 model_versions: list[str], 126 selected_models_yaml: list[dict[str, Any]], 127 rule_name: str, 128 model_name: str, 129 new_style_rule: bool, 130) -> None: 131 def find_missing_items(model_items, key, selected_models_yaml): 132 missing_items = [] 133 if not new_style_rule or not model_items: 134 return missing_items 135 for item in model_items: 136 found = False 137 for model in selected_models_yaml: 138 if str(model["model"][key]) == item: 139 found = True 140 if not found: 141 missing_items.append(item) 142 return missing_items 143 144 missing_assets = find_missing_items(model_assets, "asset", selected_models_yaml) 145 missing_versions = find_missing_items( 146 model_versions, "version", selected_models_yaml 147 ) 148 149 if len(missing_versions) > 0 or len(missing_assets) > 0: # at least one is missing 150 name_warning = "" 151 if len(selected_models_yaml) == 0: 152 name_warning = ( 153 "WARNING: 0 yaml's were found for target rule. This could be because the " 154 + "provided model name: {name} is incorrect. Please check that field as well as " 155 + "the assets and versions." 156 ).format(name=model_name) 157 raise RuntimeError( 158 ( 159 "Error: From the pt_operator_library rule for Rule: {name}, at least one entry for the " 160 + "following fields was expected -- Model: {model_name} Expected Assets: {expected_assets}, Expected Versions: " 161 + "{expected_versions}. {name_warning} In all_mobile_models.yaml either no assets were on one of the " 162 + "specified versions, one of the specified assets was not present on any of the specified " 163 + "versions, or both. Assets not found: {missing_assets}, Versions not found: {missing_versions} " 164 + "For questions please ask in https://fb.workplace.com/groups/2148543255442743/" 165 ).format( 166 name=rule_name, 167 model_name=model_name, 168 expected_versions=model_versions, 169 expected_assets=model_assets 170 if model_assets 171 else "<All model assets present on specified versions>", 172 name_warning=name_warning, 173 missing_versions=missing_versions 174 if len(missing_versions) > 0 175 else "<All specified versions had at least one asset>", 176 missing_assets=missing_assets 177 if len(missing_assets) > 0 178 else "<All specified assets are present on at least 1 version>", 179 ) 180 ) 181 182 183# Uses the selected models configs and then combines them into one dictionary, 184# formats them as a string, and places the string into output as a top level debug_info 185def create_debug_info_from_selected_models( 186 output: dict[str, object], 187 selected_models: list[dict], 188 new_style_rule: bool, 189) -> None: 190 model_dict = { 191 "asset_info": {}, # maps asset name -> dict of asset metadata like hashes 192 "is_new_style_rule": new_style_rule, 193 } 194 195 for model in selected_models: 196 model_info = model["model"] 197 asset = model_info["asset"] 198 hash = model_info["md5_hash"] 199 200 asset_info = model_dict["asset_info"].setdefault(asset, {}) 201 202 asset_info.setdefault("md5_hash", []).append(hash) 203 204 # Will later be used in gen_oplist to generate the model/version/asset checking 205 output["debug_info"] = [json.dumps(model_dict)] 206 207 208def fill_output(output: dict[str, object], options: object) -> None: 209 """Populate the output dict with the information required to serialize 210 the YAML file used for selective build. 211 """ 212 dept_graph = load_op_dep_graph(options.dep_graph_yaml_path) 213 214 model_versions = ( 215 options.model_versions.split(",") if options.model_versions is not None else [] 216 ) 217 model_assets = ( 218 options.model_assets.split(",") if options.model_assets is not None else None 219 ) 220 221 all_models_yaml = [] 222 if options.models_yaml_path: 223 for yaml_path in options.models_yaml_path: 224 with open(yaml_path, "rb") as f: 225 all_models_yaml.append(yaml.safe_load(f)) 226 227 model_filter_func = make_filter_from_options( 228 options.model_name, model_versions, model_assets, options.model_backends 229 ) 230 231 selected_models_yaml = list(filter(model_filter_func, all_models_yaml)) 232 233 verify_all_specified_present( 234 model_assets=model_assets, 235 model_versions=model_versions, 236 selected_models_yaml=selected_models_yaml, 237 rule_name=options.rule_name, 238 model_name=options.model_name, 239 new_style_rule=is_new_style_rule(options.model_name, options.model_versions), 240 ) 241 242 create_debug_info_from_selected_models( 243 output, 244 selected_models_yaml, 245 is_new_style_rule(options.model_name, options.model_versions), 246 ) 247 248 # initialize variables for static build from the pt_operator_library rule 249 if options.root_ops is not None: 250 static_root_ops = set(filter(lambda x: len(x) > 0, options.root_ops.split(","))) 251 else: 252 static_root_ops = set() 253 254 static_training_root_ops = set( 255 filter( 256 lambda x: len(x) > 0, 257 (options.training_root_ops or "").split(","), 258 ) 259 ) 260 if len(static_training_root_ops) > 0: 261 static_root_ops = static_root_ops | static_training_root_ops 262 # end if 263 264 root_ops_unexpand = set() 265 traced_ops = set() 266 training_root_ops_unexpand = set() 267 traced_training_ops = set() 268 all_kernel_metadata = [] 269 all_custom_classes = set() 270 all_build_features = set() 271 272 # Go through each yaml file and retrieve operator information. 273 for model_info in selected_models_yaml: 274 if "traced_operators" not in model_info: 275 # If this YAML file doesn't specify any traced operators, then it is using 276 # the static analysis selective build approach of finding transitively 277 # used operators, and we should update root_ops with the set of root 278 # operators, all of whose overloads must be included. In addition, these 279 # root_ops will be further expanded using the transitive closure of 280 # operator dependencies. 281 static_root_ops = static_root_ops | set(model_info["root_operators"]) 282 else: 283 # If this YAML file specifies traced operators, then it is using 284 # the tracing based selective build approach of finding used 285 # operators, and we should update root_ops_unexpand with the set of root 286 # operators whose overloads don't need to be included. In addition, these 287 # root_ops_unexpand will NOT be further expanded. If the train flag is 288 # set then the ops will be used for training, so we put them in a separate 289 # set 290 if model_info["train"]: 291 training_root_ops_unexpand = training_root_ops_unexpand | set( 292 model_info["root_operators"] 293 ) 294 traced_training_ops = traced_training_ops | set( 295 model_info["traced_operators"] 296 ) 297 else: 298 root_ops_unexpand = root_ops_unexpand | set( 299 model_info["root_operators"] 300 ) 301 traced_ops = traced_ops | set(model_info["traced_operators"]) 302 303 if "kernel_metadata" in model_info: 304 all_kernel_metadata.append(model_info["kernel_metadata"]) 305 306 if "custom_classes" in model_info: 307 all_custom_classes = all_custom_classes | set(model_info["custom_classes"]) 308 309 if "build_features" in model_info: 310 all_build_features = all_build_features | set(model_info["build_features"]) 311 312 # This following section on transitive closure is relevant to static build only 313 canonical_root_ops = canonical_opnames(static_root_ops) 314 # If no canonical_root_ops exist, don't compute the transitive closure 315 # otherwise, we will include __BASE__ and __ROOT__ ops and mark them as required 316 # for inference. 317 if len(canonical_root_ops) > 0: 318 closure_op_list = gen_transitive_closure(dept_graph, canonical_root_ops) 319 else: 320 closure_op_list = set() 321 322 canonical_training_root_ops = canonical_opnames(static_training_root_ops) 323 # If no canonical_training_root_ops exist, don't compute the transitive closure 324 # otherwise, we will include __BASE__ and __ROOT__ ops and mark them as required 325 # for training. 326 if len(canonical_training_root_ops) > 0: 327 closure_training_op_list = gen_transitive_closure( 328 dept_graph, canonical_training_root_ops, train=True 329 ) 330 else: 331 closure_training_op_list = set() 332 333 # bucketed_ops holds sets of operators that correspond to specific semantic buckets. For 334 # example: 335 # 336 # 1. Root Operators not used for training w/o full overload inclusion 337 # 2. Root Operators not used for training w/ full overload inclusion 338 # 3. Root Operators used for training w/o full overload inclusion 339 # 4. Root Operators used for training w/ full overload inclusion 340 # 5. Non-root Operators not used for training w/o full overload inclusion 341 # etc... 342 # 343 # Basically for each of the 3 boolean conditional, there are 2 344 # options (True/False). 345 # 346 bucketed_ops = [] 347 348 # START STATIC BUILD OPS 349 static_root_ops_bucket = {} 350 for op_name in static_root_ops: 351 op = SelectiveBuildOperator.from_yaml_dict( 352 op_name, 353 { 354 "is_root_operator": True, 355 "is_used_for_training": False, 356 "include_all_overloads": not options.not_include_all_overloads_static_root_ops, 357 "debug_info": [options.model_name], 358 }, 359 ) 360 static_root_ops_bucket[op_name] = op 361 bucketed_ops.append(static_root_ops_bucket) 362 363 closure_ops_bucket = {} 364 for op_name in closure_op_list: 365 op = SelectiveBuildOperator.from_yaml_dict( 366 op_name, 367 { 368 "is_root_operator": False, 369 "is_used_for_training": False, 370 "include_all_overloads": not options.not_include_all_overloads_closure_ops, 371 "debug_info": [options.model_name], 372 }, 373 ) 374 closure_ops_bucket[op_name] = op 375 bucketed_ops.append(closure_ops_bucket) 376 377 static_training_root_ops_bucket = {} 378 for op_name in static_training_root_ops: 379 op = SelectiveBuildOperator.from_yaml_dict( 380 op_name, 381 { 382 "is_root_operator": True, 383 "is_used_for_training": True, 384 "include_all_overloads": True, 385 "debug_info": [options.model_name], 386 }, 387 ) 388 static_training_root_ops_bucket[op_name] = op 389 bucketed_ops.append(static_training_root_ops_bucket) 390 391 closure_training_ops_bucket = {} 392 for op_name in closure_training_op_list: 393 op = SelectiveBuildOperator.from_yaml_dict( 394 op_name, 395 { 396 "is_root_operator": False, 397 "is_used_for_training": True, 398 "include_all_overloads": True, 399 "debug_info": [options.model_name], 400 }, 401 ) 402 closure_training_ops_bucket[op_name] = op 403 bucketed_ops.append(closure_training_ops_bucket) 404 # END STATIC BUILD OPS 405 406 # START TRACING BASED BUILD OPS 407 root_ops_unexpand_bucket = {} 408 for op_name in root_ops_unexpand: 409 op = SelectiveBuildOperator.from_yaml_dict( 410 op_name, 411 { 412 "is_root_operator": True, 413 "is_used_for_training": False, 414 "include_all_overloads": False, 415 "debug_info": [options.model_name], 416 }, 417 ) 418 root_ops_unexpand_bucket[op_name] = op 419 bucketed_ops.append(root_ops_unexpand_bucket) 420 421 traced_ops_bucket = {} 422 for op_name in traced_ops: 423 op = SelectiveBuildOperator.from_yaml_dict( 424 op_name, 425 { 426 "is_root_operator": False, 427 "is_used_for_training": False, 428 "include_all_overloads": False, 429 "debug_info": [options.model_name], 430 }, 431 ) 432 traced_ops_bucket[op_name] = op 433 bucketed_ops.append(traced_ops_bucket) 434 435 training_root_ops_unexpand_bucket = {} 436 for op_name in training_root_ops_unexpand: 437 op = SelectiveBuildOperator.from_yaml_dict( 438 op_name, 439 { 440 "is_root_operator": True, 441 "is_used_for_training": True, 442 "include_all_overloads": False, 443 "debug_info": [options.model_name], 444 }, 445 ) 446 training_root_ops_unexpand_bucket[op_name] = op 447 bucketed_ops.append(training_root_ops_unexpand_bucket) 448 449 traced_training_ops_bucket = {} 450 for op_name in traced_training_ops: 451 op = SelectiveBuildOperator.from_yaml_dict( 452 op_name, 453 { 454 "is_root_operator": False, 455 "is_used_for_training": True, 456 "include_all_overloads": False, 457 "debug_info": [options.model_name], 458 }, 459 ) 460 traced_training_ops_bucket[op_name] = op 461 bucketed_ops.append(traced_training_ops_bucket) 462 # END TRACING BASED BUILD OPS 463 464 # Merge dictionaries together to remove op duplication 465 operators: dict[str, SelectiveBuildOperator] = {} 466 for ops_dict in bucketed_ops: 467 operators = merge_operator_dicts(operators, ops_dict) 468 469 # Loop over all operators, and if any of the them specifies that 470 # all overloads need to be included, then set include_all_non_op_selectives 471 # to True, since it indicates that this operator list came from something 472 # other than a traced operator list. 473 include_all_non_op_selectives = False 474 for op_name, op_info in operators.items(): 475 include_all_non_op_selectives = ( 476 include_all_non_op_selectives or op_info.include_all_overloads 477 ) 478 479 operators_as_dict = {} 480 for k, v in operators.items(): 481 operators_as_dict[k] = v.to_dict() 482 483 output["operators"] = operators_as_dict 484 485 output["custom_classes"] = all_custom_classes 486 487 output["build_features"] = all_build_features 488 489 output["include_all_non_op_selectives"] = include_all_non_op_selectives 490 if len(all_kernel_metadata) > 0: 491 kernel_metadata = {} 492 for kt in all_kernel_metadata: 493 kernel_metadata = merge_kernel_metadata(kernel_metadata, kt) 494 output["kernel_metadata"] = kernel_metadata 495 496 497def add_arguments_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 498 parser.add_argument( 499 "--root-ops", 500 "--root_ops", 501 help="A comma separated list of root operators used by the model", 502 required=False, 503 ) 504 parser.add_argument( 505 "--training-root-ops", 506 "--training_root_ops", 507 help="A comma separated list of root operators used for training", 508 required=False, 509 ) 510 parser.add_argument( 511 "--output-path", 512 "--output_path", 513 help="The location of the output yaml file.", 514 required=True, 515 ) 516 parser.add_argument( 517 "--dep-graph-yaml-path", 518 "--dep_graph_yaml_path", 519 type=str, 520 help="A path to the Operator Dependency Graph YAML file.", 521 required=True, 522 ) 523 parser.add_argument( 524 "--model-name", 525 "--model_name", 526 type=str, 527 help="The name of the model that uses the specified root operators.", 528 required=True, 529 ) 530 parser.add_argument( 531 "--model-versions", 532 "--model_versions", 533 type=str, 534 help="A comma separated list of model versions.", 535 required=False, 536 ) 537 parser.add_argument( 538 "--model-assets", 539 "--model_assets", 540 type=str, 541 help="A comma separate list of model asset names (if absent, defaults to all assets for this model).", 542 required=False, 543 ) 544 parser.add_argument( 545 "--model-backends", 546 "--model_backends", 547 type=str, 548 default="CPU", 549 help="A comma separated list of model backends.", 550 required=False, 551 ) 552 parser.add_argument( 553 "--models-yaml-path", 554 "--models_yaml_path", 555 type=str, 556 help="The paths to the mobile model config YAML files.", 557 required=False, 558 nargs="+", 559 ) 560 parser.add_argument( 561 "--include-all-operators", 562 "--include_all_operators", 563 action="store_true", 564 default=False, 565 help="Set this flag to request inclusion of all operators (i.e. build is not selective).", 566 required=False, 567 ) 568 parser.add_argument( 569 "--rule-name", 570 "--rule_name", 571 type=str, 572 help="The name of pt_operator_library rule resulting in this generation", 573 required=True, 574 ) 575 parser.add_argument( 576 "--not-include-all-overloads-static-root-ops", 577 "--not_include_all_overloads_static_root_ops", 578 action="store_true", 579 default=False, 580 help="Set this flag to not include all overloaded operators for static root ops bucket in fill_output() subroutine", 581 required=False, 582 ) 583 parser.add_argument( 584 "--not-include-all-overloads-closure-ops", 585 "--not_include_all_overloads_closure_ops", 586 action="store_true", 587 default=False, 588 help="Set this flag to not include all overloaded operators for closure ops bucket in fill_output() subroutine", 589 required=False, 590 ) 591 return parser 592 593 594def parse_options(parser: argparse.ArgumentParser) -> argparse.Namespace: 595 return parser.parse_args() 596 597 598def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace: 599 parser = add_arguments_parser(parser) 600 return parse_options(parser) 601 602 603def main(argv) -> None: 604 parser = argparse.ArgumentParser(description="Generate used operators YAML") 605 options = get_parser_options(parser) 606 607 model_dict = { 608 "model_name": options.model_name, 609 "asset_info": {}, 610 "is_new_style_rule": False, 611 } 612 output = { 613 "debug_info": [json.dumps(model_dict)], 614 } 615 616 if options.include_all_operators: 617 output["include_all_operators"] = True 618 output["operators"] = {} 619 output["kernel_metadata"] = {} 620 else: 621 fill_output(output, options) 622 623 with open(options.output_path, "wb") as out_file: 624 out_file.write( 625 yaml.safe_dump( 626 output, 627 default_flow_style=False, 628 ).encode("utf-8") 629 ) 630 631 632if __name__ == "__main__": 633 sys.exit(main(sys.argv)) 634